
A Detailed Explanation of the ALS Algorithm in Spark

(2019-05-13 22:52:56)
Tags: spark, als, detailed explanation

Category: Machine Learning

This post draws on:

1. https://blog.csdn.net/chencheng12077/article/details/52954703

2. https://blog.csdn.net/u011239443/article/details/51752904

Both articles are excellent: the first walks through the code in more detail, while the second captures the big picture better.

I. The Core Idea of ALS

The rating matrix $A$ is approximately low-rank: an $m \times n$ rating matrix $A$ can be approximated by the product of two much smaller matrices $U_{m \times k}$ and $V_{n \times k}$, i.e. $A \approx U V^\top$.

We use the feature vector $u_i$ of the $i$-th user in the user-preference feature matrix $U$, together with the feature vector $v_j$ of the $j$-th product in the product feature matrix $V$, to predict the entry $a_{ij}$ of the rating matrix $A$. The loss function of this matrix factorization model is

$$C = \sum_{(i,j) \in R} \left( a_{ij} - u_i^\top v_j \right)^2 + \lambda \left( \sum_i \|u_i\|^2 + \sum_j \|v_j\|^2 \right)$$

where $R$ is the set of observed ratings and $\lambda$ is the regularization parameter.

Spark minimizes this loss with alternating least squares (ALS). The idea: randomly initialize $U$ and $V$, then fix $U$ and solve for $V$, fix $V$ and solve for $U$, and keep alternating until $C$ reaches its minimum.
Spark splits $U$ and $V$ each into an InBlock part and an OutBlock part. In each half-iteration it uses A's OutBlocks together with B's InBlocks to update B's InBlock factor values, where (A, B) stands for either (U, V) or (V, U).
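To make the alternation concrete, here is a minimal, self-contained sketch. It is not Spark's blocked implementation: the rank is fixed to 1 so each least-squares subproblem collapses to a scalar division, and the tiny matrix a is made up for illustration.

object AlsSketch {
  def main(args: Array[String]): Unit = {
    // a tiny, fully observed 3 x 2 rating matrix (exactly rank 1 by construction)
    val a = Array(Array(5.0, 3.0), Array(4.0, 2.4), Array(1.0, 0.6))
    val lambda = 0.01
    var u = Array.fill(3)(1.0) // user factors
    var v = Array.fill(2)(1.0) // item factors

    for (_ <- 0 until 10) {
      // fix v, solve each u(i) = (sum_j a(i)(j) * v(j)) / (sum_j v(j)^2 + lambda)
      u = a.map(row => row.zip(v).map { case (r, vj) => r * vj }.sum /
        (v.map(x => x * x).sum + lambda))
      // fix u, solve each v(j) symmetrically
      v = Array.tabulate(2) { j =>
        a.zip(u).map { case (row, ui) => row(j) * ui }.sum /
          (u.map(x => x * x).sum + lambda)
      }
    }
    // the reconstructed ratings u(i) * v(j) should be close to a
    a.indices.foreach(i => println(v.map(vj => f"${u(i) * vj}%.2f").mkString(" ")))
  }
}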

II. Algorithm Steps
1. Preparation
(1) Build userPart and itemPart; the computation proceeds partition by partition.
(2) Construct userLocalIndexEncoder and itemLocalIndexEncoder, which pack block ids and local index information into single integers.

2. The PartitionRatings step
(1) Compute srcBlockId = srcPart(userPart).getPartition(userId) and dstBlockId = dstPart(itemPart).getPartition(itemId), which gives the mapping

(userId, itemId, rating) → (srcBlockId, dstBlockId)

If a userId and an itemId are linked by a rating, their srcBlockId and dstBlockId are necessarily linked as well. A small illustration of this mapping follows after step (2).
(2) Gather all rating triples (userId, itemId, rating) with the same (srcBlockId, dstBlockId) into a RatingBlock, which stores (userIds, itemIds, ratings). The final result has the shape ((srcBlockId, dstBlockId), RatingBlock).
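How an id maps to a block id, using the 2-partition HashPartitioner from the test code in Part III (the ids are taken from its sample data):

import org.apache.spark.HashPartitioner

object BlockIdDemo {
  def main(args: Array[String]): Unit = {
    val userPart = new HashPartitioner(2)
    val itemPart = new HashPartitioner(2)
    // the rating (userId = 176, itemId = 3, rating = 2.0)
    val srcBlockId = userPart.getPartition(176) // 176 % 2 = 0
    val dstBlockId = itemPart.getPartition(3)   // 3 % 2 = 1
    println(s"key = ($srcBlockId, $dstBlockId)") // key = (0, 1)
  }
}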

3. Building the InBlocks
For each (srcBlockId, dstBlockId):
(1) deduplicate and sort the dstIds in the RatingBlock, yielding a (dstId, index) mapping;
(2) map every dstId in the RatingBlock to its index, yielding dstLocalIndices; the value now has the shape (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings));
(3) encode dstBlockId together with each dstLocalIndex into dstEncodedIndices; the value now has the shape (srcBlockId, (srcIds, dstEncodedIndices, ratings));
(4) convert the value into a CSC-like format: sort the srcIds, permuting the elements of dstEncodedIndices and ratings to match; then deduplicate the srcIds and record the range each srcId covers, yielding dstPtrs. The value now has the shape (srcBlockId, InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)). A worked example follows.
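A tiny worked example of the CSC-like compression in step (4), with made-up data (independent of the classes in Part III):

object CscDemo {
  def main(args: Array[String]): Unit = {
    // already sorted by srcId: srcId 7 owns two ratings, srcId 9 owns one
    val srcIds            = Array(7, 7, 9)
    val dstEncodedIndices = Array(0, 3, 1)
    val ratings           = Array(5.0, 3.0, 4.0)

    // after compression:
    val uniqueSrcIds = Array(7, 9)
    val dstPtrs      = Array(0, 2, 3) // srcId 7 covers [0, 2), srcId 9 covers [2, 3)

    // dstPtrs lets us walk the ratings belonging to each unique srcId
    for (i <- uniqueSrcIds.indices; j <- dstPtrs(i) until dstPtrs(i + 1))
      println(s"srcId=${uniqueSrcIds(i)} dstEncodedIndex=${dstEncodedIndices(j)} rating=${ratings(j)}")
  }
}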

4. Building the OutBlocks
For each srcBlockId, record which srcIds are linked to each dstBlockId, i.e. something like:

srcBlockId → dstBlockId1: (srcId1, srcId2)
           → dstBlockId2: (srcId3, srcId4)

where srcId1/srcId2 may overlap with srcId3/srcId4. In the code the srcIds are stored as local indices into the block's srcIds array, and the final result has the shape (srcBlockId, activeIds), as illustrated below.
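A made-up illustration of what one OutBlock entry means:

object OutBlockDemo {
  def main(args: Array[String]): Unit = {
    val srcIds = Array(7, 9, 12) // the users owned by this user block
    // activeIds(dstBlockId) lists the local indices (into srcIds) of the users
    // that have at least one rating inside that item block
    val activeIds: Array[Array[Int]] = Array(
      Array(0, 2), // item block 0 is linked to srcIds(0) = 7 and srcIds(2) = 12
      Array(0, 1)  // item block 1 is linked to srcIds(0) = 7 and srcIds(1) = 9
    )
    for (dstBlockId <- activeIds.indices)
      println(s"item block $dstBlockId needs the factors of users " +
        activeIds(dstBlockId).map(i => srcIds(i)).mkString(", "))
  }
}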

5. Steps 3 and 4 yield (userInBlocks, userOutBlocks) and (itemInBlocks, itemOutBlocks).

6. Initialize userFactors and itemFactors from userInBlocks and itemInBlocks respectively. userFactors has the shape (srcBlockId, factors), where factors is an array holding one factor vector per srcId, indexed by local index.

7. Alternately compute userFactors and itemFactors. We first describe computing itemFactors from userFactors.
(1) From srcOutBlocks and srcFactorBlocks (the factors produced in the previous pass), obtain for every dstBlockId the factor of each linked srcId in each linked srcBlockId, i.e.:

(dstBlockId, (srcBlockId, Array(factor1, factor2, ...)))

Then aggregate on dstBlockId, which gives:

(dstBlockId, Iterable[(srcBlockId, Array(factor1, factor2, ...))])

Call this the merged RDD.
(2) Join the InBlocks with merged. Here the InBlocks have the shape (dstBlockId, InBlock(uniqueDstIds, srcPtrs, srcEncodedIndices, ratings)), and merged has the shape (dstBlockId, a three-level srcBlockId × localIndex × factor array).
Each element of srcEncodedIndices in the InBlocks decodes into a srcBlockId and a srcLocalIndex, from which its factor can be looked up. srcPtrs delimits how many srcIds belong to each uniqueDstId; taking the factors of all those srcIds together with the corresponding dstId's ratings from the InBlocks, we compute each dstId's factor.
(3) userFactors is computed from itemFactors in exactly the same way.
(3)同理可以从itemFactors计算出userFactors。

8. Updating the factors by solving a normal equation
(1) The normal equation. Suppose the userFactors and ratings are known and we want the itemFactor satisfying userFactor * itemFactor = rating. Fitting this over all observations is a least-squares problem; minimizing $\|Ax - b\|^2 + \lambda \|x\|^2$ leads to the normal equation, whose solution has the form

$$x = (A^\top A + \lambda I)^{-1} A^\top b$$

where $A^\top A$ is the ata in the code and $A^\top b$ is the atb. For each dstId in step 7, we accumulate its data into a normal equation and then solve it.
(2) The add operation maintains ata as a packed upper-triangular matrix, performing the rank-1 updates
ata = da * daᵀ + ata
atb = b * da + atb
(3) The solve method that updates a factor proceeds as follows (see the summary after this list):
first compute the residual res = ata * x - atb;
copy res into dir, which plays the role of the gradient;
compute the step length step = (dirᵀ * res) / (dirᵀ * ata * dir + ε);
then scan dir: whenever step * dir(i) > x(i), shrink the step to step = x(i) / dir(i), so that x stays non-negative;
finally take the step: x(i) = x(i) - step * dir(i).
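In equation form, one iteration of this projected-gradient scheme can be summarized as follows (my notation, condensed from the NNLS2.scala code below):

$$\begin{aligned} r &= (A^\top A)\,x - A^\top b, \qquad d = r \\ \alpha &= \frac{d^\top r}{d^\top (A^\top A)\, d + \varepsilon}, \qquad \alpha \leftarrow \min\Big(\alpha,\ \min_{i:\ \alpha d_i > x_i} \frac{x_i}{d_i}\Big) \\ x &\leftarrow \max(0,\ x - \alpha d) \end{aligned}$$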

III. Code (unpolished; for reference only)
1. The test driver, BlasTest.scala
package com.pr.fortest.als

import org.apache.log4j.{Level, Logger}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object BlasTest {

  def main(args: Array[String]): Unit = {
    // Suppress noisy logging
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    // Set up the runtime environment
    val conf = new SparkConf().setAppName("ALSTest").setMaster("local[2]")
    val sc = new SparkContext(conf)

    val ratingsList = List((1, 2, 1.0), (176, 3, 2.0), (177, 3, 2.0),(178, 3, 2.0),(179, 3, 2.0),(175, 3, 2.0),(200, 1, 1.0), (25, 3, 3.0))
    val rank = 10
    val maxIter = 11

    val ratings = sc.parallelize(ratingsList).map(x => Rating(x._1, x._2, x._3))
    val userPart = new HashPartitioner(2)
    val itemPart = new HashPartitioner(2)
    val userLocalIndexEncoder = new LocalIndexEncoder2(userPart.numPartitions)
    val itemLocalIndexEncoder = new LocalIndexEncoder2(itemPart.numPartitions)

    val partitionRatings = new PartitionRatings()
    val blockRatings = partitionRatings.partitionRatings(ratings, userPart, itemPart)

    val (userInBlocks, userOutBlocks) = partitionRatings.makeBlocks("user", blockRatings, userPart, itemPart, StorageLevel.MEMORY_AND_DISK)
    userOutBlocks.count()

    val swappedBlockRatings = blockRatings.map{
      case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
        ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
    }

    val (itemInBlocks, itemOutBlocks) = partitionRatings.makeBlocks("item", swappedBlockRatings, itemPart, userPart, StorageLevel.MEMORY_AND_DISK)
    itemOutBlocks.count()

    val factorCalculator = new FactorCalculator()
    var userFactors = factorCalculator.initialize(userInBlocks, rank)
    var itemFactors = factorCalculator.initialize(itemInBlocks, rank)

    for (iter <- 0 until maxIter) { // run exactly maxIter alternating sweeps
      itemFactors = factorCalculator.computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, 0.01, userLocalIndexEncoder)
      userFactors = factorCalculator.computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, 0.01, itemLocalIndexEncoder)
    }

    itemFactors.take(10).foreach{case(idx, arr_arr) =>
      println(idx)
      println(arr_arr(0).mkString(":"))
      println("-------------")
    }

    sc.stop()

  }
}

2. FactorCalculator.scala, the factor-computation logic
package com.pr.fortest.als

import org.apache.spark.rdd.RDD
import scala.util.Random
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.HashPartitioner

// Computes the factors
class FactorCalculator {

  // Initialize one random, normalized factor vector per srcId
  def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int,  Array[Array[Float]])] = {
    inBlocks.map{case(srcBlockId, inBlock) =>
        val factors = Array.fill(inBlock.srcIds.length){
          val factor = Array.fill(rank)(Random.nextGaussian().toFloat)
          val nrm = blas.snrm2(rank, factor, 1)
          blas.sscal(rank, 1.0f/nrm, factor, 1)  // normalize to unit length
          factor
        }
      (srcBlockId, factors)
    }
  }

  def computeFactors(  // comments below take users as src and items as dst
                    srcFactorBlocks: RDD[(Int, Array[Array[Float]])],  // user block id, plus one rank-length factor array per user in the block
                    srcOutBlocks: RDD[(Int, Array[Array[Int]])], // user block id, plus the local user indices linked to each item block
                    dstInBlocks: RDD[(Int, InBlock)], // item block: unique item ids / cumulative counts of users per unique item id / encoded user indices / ratings
                    rank: Int,
                    regParam: Double,
                    srcEncoder: LocalIndexEncoder2): RDD[(Int, Array[Array[Float]])] = {
    val numSrcBlocks = srcFactorBlocks.partitions.length
    val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap{
      case(srcBlockId, (srcOutBlock, srcFactors)) =>
        srcOutBlock.zipWithIndex.map{case(activeIndices, dstBlockId) => // recover the item block id
          (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
        }
    } // key: item block id; value: (user block id, the factor arrays of the users in that block linked to this item block)

    val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length))
    dstInBlocks.join(merged).mapValues{
      case(InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
        val sortedSrcFactors = new Array[Array[Array[Float]]](numSrcBlocks)
        srcFactors.foreach{ case(srcBlockId, factors) =>
          sortedSrcFactors(srcBlockId) = factors
        }
        val dstFactors = new Array[Array[Float]](dstIds.length)
        var j = 0
        val ls = new NormalEquation2(rank)
        while(j < dstIds.length){
          ls.reset()
          var i = srcPtrs(j)
          var numExplicits = 0
          while(i < srcPtrs(j + 1)){
            val encoded = srcEncodedIndices(i)
            val blockId = srcEncoder.blockId(encoded)
            val localIndex = srcEncoder.localIndex(encoded)
            val srcFactor = sortedSrcFactors(blockId)(localIndex)
            val rating = ratings(i)
            ls.add(srcFactor, rating)
            numExplicits += 1
            i += 1
          }
          dstFactors(j) = NNLS2.solve(rank, ls.ata, ls.atb, regParam * numExplicits) // scale the regularization by the number of explicit ratings
          j += 1
        }
        dstFactors
    }
  }

}

3. LocalIndexEncoder2.scala, which packs a block id plus a local index into a single Int
package com.pr.fortest.als

class LocalIndexEncoder2(numBlocks: Int) extends Serializable {
  val numLocalIndexBits = math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)  // leading zeros of (numBlocks - 1), i.e. how many bits are left over for the local index
  val localIndexMask = (1 << numLocalIndexBits) - 1 // mask for the local-index bits

  def getIndexBit(): Int = numLocalIndexBits
  def getLocalIndexMask(): String = localIndexMask.toBinaryString

  // Pack (blockId, localIndex) into a single integer
  def encode(blockId: Int, localIndex: Int): Int = {
    require(blockId < numBlocks)
    require((localIndex & ~localIndexMask) == 0)  // the blockId bits and localIndex bits must not overlap, i.e. localIndex must not spill into the blockId range
    (blockId << numLocalIndexBits) | localIndex
  }

  def blockId(encoded: Int): Int = {
    encoded >>> numLocalIndexBits  // note >>>, not >>: the shift must be unsigned
  }

  def localIndex(encoded: Int): Int = {
    encoded & localIndexMask
  }

}
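A quick usage illustration (hypothetical values; with numBlocks = 2, 31 bits remain for the local index, so the block id lives in the top bit):

object EncoderDemo {
  def main(args: Array[String]): Unit = {
    val encoder = new LocalIndexEncoder2(2)
    val encoded = encoder.encode(1, 5)   // blockId = 1, localIndex = 5
    println(encoder.blockId(encoded))    // 1
    println(encoder.localIndex(encoded)) // 5
  }
}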

4. NNLS2.scala, which solves the non-negative least-squares problem
package com.pr.fortest.als

import java.util
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import scala.util.Random

// Solves the non-negative least-squares problem with a modified projected-gradient method
object NNLS2 {
  private var ata: Array[Double] = _
  private var rank: Int = -1
  private var workspace: Workspace = _

  // Scratch space for the intermediate computations
  class Workspace(val n: Int){
    val scratch = new Array[Double](n)
    val grad = new Array[Double](n)
    val x = new Array[Double](n)
    val dir = new Array[Double](n)
    val lastDir = new Array[Double](n)
    val res = new Array[Double](n)

    def wipe(): Unit = {
      util.Arrays.fill(scratch, 0.0)
      util.Arrays.fill(grad, 0.0)
      util.Arrays.fill(x, 0.0)
      util.Arrays.fill(dir, 0.0)
      util.Arrays.fill(lastDir, 0.0)
      util.Arrays.fill(res, 0.0)
    }
  }

  def createWorkspace(n: Int): Workspace = {
    new Workspace(n)
  }

  // Initialize the intermediate state
  def initialize(rank: Int): Unit = {
    this.rank = rank
    workspace = createWorkspace(rank)
    ata = new Array[Double](rank * rank)
  }

  // Expand the packed upper-triangular values into the full ata matrix
  def fillAtA(triAtA: Array[Double], lambda: Double): Unit = {
    var i = 0
    var pos = 0
    var a = 0.0
    while (i < rank) {
      var j = 0
      while (j <= i) {
        a = triAtA(pos)
        ata(i * rank + j) = a
        ata(j * rank + i) = a
        pos += 1
        j += 1
      }
      ata(i * rank + i) += lambda
      i += 1
    }
  }

  // Main entry point of solve
  def solve(k: Int, ata1: Array[Double], atb1: Array[Double], lambda: Double): Array[Float] = {
    val rank = k
    initialize(rank)
    fillAtA(ata1, lambda)
    val x = solve(ata, atb1, workspace)
    x.map(x => x.toFloat)
  }

  // Computes the factor by gradient descent; the subtle part is how the step length is updated
  def solve(ata: Array[Double], atb: Array[Double], ws: Workspace): Array[Double] = {
    ws.wipe()

    val n = atb.length
    val scratch = ws.scratch


    // Step-length formula. dir: gradient direction, res: residual
    def steplen(dir: Array[Double], res: Array[Double]): Double = {
      val top = blas.ddot(n, dir, 1, res, 1)
      blas.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1)
      top / (blas.ddot(n, scratch, 1, dir, 1) + 1e-20)
    }

    def stop(step: Double, ndir: Double, nx: Double): Boolean = {
      ((step.isNaN)
        || (step < 1e-7)
        || (step > 1e40)
        || (ndir < 1e-12 * nx)
        || (ndir < 1e-32)
        )
    }

    val grad = ws.grad
    val x = ws.x
    val dir = ws.dir
    val lastDir = ws.lastDir
    val res = ws.res
    val iterMax = math.max(400, 20 * n)
    var lastNorm = 0.0
    var iterno = 0
    var lastWall = 0
    var i = 0
    while(iterno < iterMax){
      blas.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1)
      blas.daxpy(n, -1.0, atb, 1, res, 1)
      blas.dcopy(n, res, 1, grad, 1)

      i = 0
      while(i < n){
        if(grad(i) > 0.0 && x(i) == 0.0){
          grad(i) = 0.0
        }
        i = i + 1
      }

      val ngrad = blas.ddot(n, grad, 1, grad, 1)
      blas.dcopy(n, grad, 1, dir, 1)

      var step = steplen(grad, res)
      var ndir = 0.0
      val nx = blas.ddot(n, x, 1, x, 1)
      if(iterno > lastWall + 1){
        val alpha = ngrad / lastNorm
        blas.daxpy(n, alpha, lastDir, 1, dir, 1)
        val dstep = steplen(dir, res)
        ndir = blas.ddot(n, dir, 1, dir, 1)
        if (stop(dstep, ndir, nx)) {
          // reject the CG step if it could lead to premature termination
          blas.dcopy(n, grad, 1, dir, 1)
          ndir = blas.ddot(n, dir, 1, dir, 1)
        } else {
          step = dstep
        }
      }else{
        ndir = blas.ddot(n, dir, 1, dir, 1)
      }

      if(stop(step, ndir, nx)){
        return x.clone
      }

      i = 0
      while(i < n){
        if(step * dir(i) > x(i)){
          step = x(i) / dir(i)
        }
        i = i + 1
      }

      i = 0
      // take the step, projecting onto x >= 0
      while(i < n){
        if(step * dir(i) > x(i) * (1 - 1e-14)){
          x(i) = 0
          lastWall = iterno
        }else{
          x(i) -= step * dir(i)
        }
        i = i + 1
      }

      iterno = iterno + 1
      blas.dcopy(n, dir, 1, lastDir, 1)
      lastNorm = ngrad
    }
    x.clone
  }

  // quick self-test
  def main(args: Array[String]) = {
    val n = 10
    val lambda = 0.02

    val ata = Array.fill(n*(n+1)/2)(Random.nextGaussian())
    val atb = Array.fill(n)(Random.nextGaussian())
    println("[ata]" + ata.mkString(":"))
    println("[atb]" + atb.mkString(":"))

    val result = solve(n, ata, atb, lambda).mkString(":")
    println(result)
  }

}

5. NormalEquation2.scala, the normal-equation class
package com.pr.fortest.als

import com.github.fommil.netlib.BLAS.{getInstance => blas}

class NormalEquation2(val k: Int) extends Serializable {
  val triK = k * (k + 1) / 2
  val ata = new Array[Double](triK)
  val atb = new Array[Double](k)

  val da = new Array[Double](k)
  val upper = "U"

  def copyToDouble(a: Array[Float]) = {
    var i = 0
    while(i < k){
      da(i) = a(i)
      i += 1
    }
  }

  def add(a: Array[Float], b: Double, c: Double=1.0): this.type = {
    require(c >= 0.0)
    require(a.length == k)
    copyToDouble(a)
    blas.dspr(upper, k, c, da, 1, ata)
    if(b != 0.0){
      blas.daxpy(k, b, da, 1, atb, 1)
    }
    this
  }

  def merge(other: NormalEquation2): this.type = {
    require(other.k == k)
    blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
    blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
    this
  }

  def reset(): Unit = {
    // zero the accumulators so this instance can be reused for the next dstId
    java.util.Arrays.fill(ata, 0.0)
    java.util.Arrays.fill(atb, 0.0)
  }

}
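A small usage sketch tying NormalEquation2 and NNLS2 together (made-up factors and ratings):

object NormalEquationDemo {
  def main(args: Array[String]): Unit = {
    val rank = 2
    val ls = new NormalEquation2(rank)
    ls.add(Array(1.0f, 0.5f), 4.0) // (srcFactor, rating)
    ls.add(Array(0.2f, 1.0f), 3.0)
    val factor = NNLS2.solve(rank, ls.ata, ls.atb, 0.01)
    println(factor.mkString(":"))
  }
}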


6. PartitionRatings.scala, which builds the InBlocks and OutBlocks
package com.pr.fortest.als

import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.OpenHashSet

import scala.collection.mutable
import scala.util.Sorting

// A raw user rating
case class Rating(user: Int, item: Int, rating: Double)

// A rating block stores multiple src ids, dst ids and ratings
case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Double]){
  def size: Int = srcIds.length
  require(dstIds.length == srcIds.length)
  require(ratings.length == srcIds.length)
}

// Builder for RatingBlock
class RatingBlockBuilder extends Serializable{
  private val srcIds = mutable.ArrayBuilder.make[Int]
  private val dstIds = mutable.ArrayBuilder.make[Int]
  private val ratings = mutable.ArrayBuilder.make[Double]

  var size = 0

  // Add one rating
  def add(r: Rating): this.type = {
    size += 1
    srcIds += r.user
    dstIds += r.item
    ratings += r.rating
    this
  }

  // Merge in another RatingBlock
  def merge(other: RatingBlock): this.type = {
    size += other.srcIds.length
    srcIds ++= other.srcIds
    dstIds ++= other.dstIds
    ratings ++= other.ratings
    this
  }

  // Build the RatingBlock
  def build(): RatingBlock = {
    RatingBlock(srcIds.result(), dstIds.result(), ratings.result())
  }

}

// Puts the raw ratings into blocks
// Result shape: ((srcBlockId, dstBlockId), ratingBlock)
class PartitionRatings {

  // Put the raw rating data into blocks
  // Returns: ((srcBlockId, dstBlockId), ratingBlock)
  def partitionRatings(ratings: RDD[Rating], srcPart: Partitioner, dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = {
    val numPartitions = srcPart.numPartitions * dstPart.numPartitions
    ratings.mapPartitions{iter =>
      val builders = Array.fill(numPartitions)(new RatingBlockBuilder)
      iter.flatMap{ r =>
        val srcBlockId = srcPart.getPartition(r.user) // derive a partition id from the key
        val dstBlockId = dstPart.getPartition(r.item)
        val idx = srcBlockId + dstBlockId * srcPart.numPartitions
        val builder = builders(idx)
        builder.add(r)
        if(builder.size >= 2048){  // flush a builder once it grows too large
          builders(idx) = new RatingBlockBuilder
          Iterator.single(((srcBlockId, dstBlockId), builder.build()))
        }else {
          Iterator.empty
        }
      } ++ { // flush whatever is left in the builders
        builders.zipWithIndex.filter(_._1.size > 0).map{case(block, idx) =>
          val srcBlockId = idx % srcPart.numPartitions
          val dstBlockId = idx / srcPart.numPartitions
          ((srcBlockId, dstBlockId), block.build())
        }
      }
    }.groupByKey().mapValues{blocks =>
      val builder = new RatingBlockBuilder
      blocks.foreach(builder.merge)
      builder.build()
    }.setName("ratingBlocks")
  }

  // Create the in-blocks and out-blocks from the rating blocks
  def makeBlocks(
                prefix: String,
                ratingBlocks: RDD[((Int, Int), RatingBlock)],
                srcPart: Partitioner,
                dstPart: Partitioner,
                storageLevel: StorageLevel): (RDD[(Int, InBlock)], RDD[(Int, Array[Array[Int]])]) = {
    val inBlocks = ratingBlocks.map{
      case((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
        val dstIdSet = new OpenHashSet[Int](1 << 20)
        dstIds.foreach(dstIdSet.add)
        val sortedDstIds = new Array[Int](dstIdSet.size)
        var i = 0
        var pos = dstIdSet.nextPos(0)
        while(pos != -1){
          sortedDstIds(i) = dstIdSet.getValue(pos)
          pos = dstIdSet.nextPos(pos + 1)
          i += 1
        }
        Sorting.quickSort(sortedDstIds)  // dstIds, deduplicated and sorted
        val dstIdToLocalIndex = new mutable.OpenHashMap[Int, Int](sortedDstIds.length)  // maps each dstId to its local index
        i = 0
        while(i < sortedDstIds.length){
          dstIdToLocalIndex.update(sortedDstIds(i), i)
          i += 1
        }
        val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) // turn dstIds into local indices
        (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
    }.groupByKey(new HashPartitioner(srcPart.numPartitions))
      .mapValues{iter =>
        val builder = new UncompressedInBlockBuilder(new LocalIndexEncoder2(dstPart.numPartitions))
        iter.foreach{case(dstBlockId, srcIds, dstLocalIndices, ratings) =>
            builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
        }
        builder.build().compress()
      }.setName(prefix + "InBlocks")
      .persist(storageLevel)
    val outBlocks = inBlocks.mapValues{ case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
        val encoder = new LocalIndexEncoder2(dstPart.numPartitions)
        val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
        var i = 0
        val seen = new Array[Boolean](dstPart.numPartitions)

        while(i < srcIds.length){
          var j = dstPtrs(i)
          java.util.Arrays.fill(seen, false)
          while(j < dstPtrs(i + 1)){
            val dstBlockId = encoder.blockId(dstEncodedIndices(j))
            if(!seen(dstBlockId)){  // seen marks whether this user is already linked to the given item block
              activeIds(dstBlockId) += i
              seen(dstBlockId) = true
            }
            j += 1
          }
          i += 1
        }
        // the final result is a 2-D array: the row index is the dstBlockId,
        // and each row holds the local src indices linked to that dstBlockId
        activeIds.map(_.result())
    }.setName(prefix + "OutBlocks")
      .persist(storageLevel)

    (inBlocks, outBlocks)
  }

}

// Builder for an uncompressed in-block, whose entries are (srcId, dstEncodedIndex, rating)
class UncompressedInBlockBuilder(encoder: LocalIndexEncoder2){
  private val srcIds = mutable.ArrayBuilder.make[Int]
  private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]  // integers packing (blockId, dstLocalIndex)
  private val ratings = mutable.ArrayBuilder.make[Double]

  def add(dstBlockId: Int, srcIds: Array[Int], dstLocalIndices: Array[Int], ratings: Array[Double]): this.type = {
    val sz = srcIds.length
    require(dstLocalIndices.length == sz)
    require(ratings.length == sz)

    this.srcIds ++= srcIds
    this.ratings ++= ratings
    var j = 0
    while(j < sz){
      this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
      j += 1
    }
    this
  }

  def build(): UncompressedInBlock = {
    new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
  }
}

// One block's uncompressed data: (srcIds, dstEncodedIndices, ratings)
class UncompressedInBlock(val srcIds: Array[Int], val dstEncodedIndices: Array[Int], val ratings: Array[Double]){
  def length: Int = srcIds.length

  def compress() = {
    sort()
    val sz = length
    val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
    val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
    var preSrcId = srcIds(0)
    uniqueSrcIdsBuilder += preSrcId
    var curCount = 1
    var i = 1
    while(i < sz){
      val srcId = srcIds(i)
      if(srcId != preSrcId){
        uniqueSrcIdsBuilder += srcId
        dstCountsBuilder += curCount
        preSrcId = srcId
        curCount = 0
      }
      curCount += 1
      i += 1
    }
    dstCountsBuilder += curCount

    val uniqueSrcIds = uniqueSrcIdsBuilder.result()
    val numUniqueSrcIds = uniqueSrcIds.length
    val dstCounts = dstCountsBuilder.result()
    val dstPtrs = new Array[Int](numUniqueSrcIds + 1)
    var sum = 0
    i = 0
    while(i < numUniqueSrcIds){
      sum += dstCounts(i)
      i += 1
      dstPtrs(i) = sum
    }
    InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
  }

  // Sort srcIds in ascending order, permuting the other arrays to match
  def sort() = {
    val srcIdsCopy = srcIds.clone()
    val id2IndexMap = srcIdsCopy.zipWithIndex.sortBy(_._1).zipWithIndex.map(x => (x._1._2, x._2)).toMap
    val dstEncodedIndicesCopy = dstEncodedIndices.clone()
    val ratingsCopy = ratings.clone()

    var i = 0
    while(i < length){
      val newIndex = id2IndexMap(i)
      srcIds(newIndex) = srcIdsCopy(i)
      dstEncodedIndices(newIndex) = dstEncodedIndicesCopy(i)
      ratings(newIndex) = ratingsCopy(i)
      i += 1
    }
  }

  def printElem() = {
    println("srcIds new:")
    srcIds.foreach(println)
    println("dstEncodedIndices new:")
    dstEncodedIndices.foreach(println)
    println("ratings new:")
    ratings.foreach(println)
  }

}

// In-block data used to compute the (user/item) factors. In-link information is stored in a
// CSC-like format, so a single normal-equation instance can compute the src factors one by one.
// srcIds: the src ids
// dstPtrs: the dst indices of srcIds(i) occupy the range [dstPtrs(i), dstPtrs(i + 1))
// dstEncodedIndices: encoded dst indices
case class InBlock(
                  srcIds: Array[Int],
                  dstPtrs: Array[Int],
                  dstEncodedIndices: Array[Int],
                  ratings: Array[Double]){
  def size: Int = ratings.length
  require(dstEncodedIndices.length == size)
  require(dstPtrs.length == srcIds.length + 1)
}
