spark
算法原理
協同過濾是用來對用戶的興趣偏好做預測的一種方法。在Spark中實現的是基于潛在因子模型的協同過濾。用戶對特定物品的偏好往往可以用評分的形式給出,評分矩陣r的行數對應用戶數量,列數對應物品總數,比如以下4個用戶對四個電影評分:
本方法的核心在于把評分矩陣分解為用戶偏好矩陣(x)和物品偏好因子矩陣(y):
我們的目標是找到最佳的x和y矩陣,使得這兩個矩陣相乘時得到的預測矩陣與原始評分矩陣r之間的誤差最小。轉化為數學描述,就是使得以下目標函數最小化:
該目標函數由兩部分構成,前半部分是平方誤差,后半部分使用L2正則化,引入 λ 常數,對模型的復雜度進行控制,防止過度擬合訓練數據。
Spark使用的是帶正則化矩陣分解,優化函數的方式選用的是交叉最小二乘法ALS(alternative least squares),它的一般執行步驟如下:
- 用隨機數初始化物品偏好因子矩陣y
- 固定y,找到可以最小化目標函數的用戶偏好矩陣x
- 固定x,類同步驟2,找到最小化目標函數的物品偏好因子矩陣y
- 重復步驟2和3,直到滿足算法收斂條件
ALS spark mllib 代碼
參數詳解:
輸入參數名稱 | 數據格式 | 必填/可選/固定 | 默認值 | 取值范圍 | 備注 |
---|---|---|---|---|---|
userCol | String | 必填 | 用戶ID | ||
itemCol | String | 必填 | 物品ID | ||
ratingCol | String | 必填 | 用戶給物品的評分列 | ||
rank | Int | 可選 | 10 | ≥1 | 潛在因子數量,最優值需要根據具體數據制定 |
maxIter | Int | 可選 | 10 | ≥0 | 最大循環次數 |
lambda | Double | 可選 | 0.01 | ≥0 | 正則化參數 λ |
numUserBlocks | Int | 可選 | 10 | ≥1 | 把用戶偏好矩陣拆分成小塊以滿足并行化需求 |
numItemBlocks | Int | 可選 | 10 | ≥1 | 把物品偏好因子矩陣拆分成塊以滿足并行化需求 |
implicitPrefs | Boolean | 必填 | false | false或true | 是否為推測出來的用戶偏好(比如,如果一個用戶購買過物品A,則推測對A有偏好)。 |
alpha | Double | 可選 | 1.0 | ≥0 | implicitPrefs為true時,根據用戶的評分,在confidence基準值之上,進行額外加分 |
nonnegative | Boolean | 可選 | false | false或true | 在最小平方差優化時,是否加以“非負值”限定 |
輸出參數名稱 | 數據格式 | 必填/可選/固定 | 備注 |
---|---|---|---|
prediction | Float | 固定 | 預測值列 |
spark代碼
// 讀入數據
val ratings = sparkContext.textFile("data/mllib/als/sample_movielens_ratings.txt").map(
_.split("::") match { case Array(user, product, rating, timeStamp) =>
Rating(user.toInt, product.toInt, rating.toDouble)
})
df = sqlContext.createDataFrame(ratings)
// 參數值設定
val userCol = "user"
val itemCol = "product"
val ratingCol = "rating"
val rank = 10
val maxIter = 10
val regParam = 0.1
val numUserBlocks = 10
val numItemBlocks = 10
val implicitPrefs = false
val alpha = 1.0
val nonnegative = false
// 建立模型
val als = new ALS(userCol, itemCol, ratingCol, rank, maxIter, numUserBlocks, numItemBlocks, implicitPrefs, alpha, nonnegative)
// 模型訓練
val alsModel = als.fit(df)
// 進行預測
val predResult = alsModel.transform(df)
val toDouble = udf[Double, Float]( _.toDouble)
val newPredResult = predResult.withColumn("predictionNew", toDouble(predResult("prediction")))
// 計算RMSE(模型評價)
val predRDD = newPredResult.select("predictionNew", "rating").rdd.map(r => (r.getDouble(0), r.getDouble(1)))
val regMetric = new RegressionMetrics(predRDD)
val rmseSpark = regMetric.rootMeanSquaredError
println(s"RMSE for ALS model: ${rmseSpark}")
本地實例
1.測試數據
userID | itemID | ratings |
---|---|---|
101 | 1001 | 4.0 |
101 | 1002 | 2.5 |
101 | 1004 | 3.0 |
101 | 1007 | 1.5 |
101 | 1010 | 4.0 |
101 | 1016 | 3.5 |
101 | 1022 | 4.0 |
102 | 1002 | 2.5 |
102 | 1003 | 1.0 |
102 | 1004 | 3.5 |
102 | 1006 | 2.0 |
102 | 1009 | 2.5 |
102 | 1011 | 4.0 |
102 | 1013 | 3.5 |
102 | 1015 | 4.0 |
102 | 1017 | 4.5 |
102 | 1022 | 5.0 |
103 | 1003 | 1.5 |
103 | 1005 | 1.0 |
103 | 1006 | 3.5 |
103 | 1008 | 2.0 |
103 | 1010 | 4.5 |
103 | 1014 | 3.0 |
103 | 1015 | 3.5 |
103 | 1021 | 5.0 |
103 | 1022 | 1.5 |
103 | 1023 | 5.0 |
104 | 1001 | 0.5 |
104 | 1003 | 3.0 |
104 | 1004 | 1.5 |
104 | 1007 | 1.0 |
104 | 1008 | 2.5 |
104 | 1011 | 1.0 |
104 | 1015 | 3.5 |
104 | 1018 | 4.0 |
104 | 1019 | 1.5 |
104 | 1020 | 3.0 |
2.訓練
package ALSdemo
import java.io.File
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD
object alsTest {
//屏蔽不必要的日志顯示在終端上
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
//給用戶推薦
val conf = new SparkConf().setMaster("local[2]").setAppName("als_test_wy")
val sc = new SparkContext(conf)
val myModelPath = "E:\\Spark\\scala-data\\Model\\alsTest"
val data = sc.textFile("E:\\Spark\\scala-data\\CBRec\\als_rating_test.txt")
val ratings: RDD[Rating] = data.map(_.split("#") match { case Array(user, item, rate) =>
Rating(user.toInt, item.toInt, rate.toDouble)
})
ratings.filter(x => x.user == 101).foreach(println)
// Build the recommendation model using ALS
val rank = 5
val numIterations = 10
val model = ALS.train(ratings, rank, numIterations, 0.01)
val recommendProducts: Array[Rating] = model.recommendProducts(101, 10)
for (r <- recommendProducts) {
println(r.toString)
}
val path: File = new File(myModelPath)
dirDel(path) //刪除原模型保存的文件,不刪除新模型保存會報錯
model.save(sc, myModelPath)
}
//刪除模型目錄和文件
def dirDel(path: File) {
if (!path.exists())
return
else if (path.isFile) {
path.delete()
return
}
val file: Array[File] = path.listFiles()
for (d <- file) {
dirDel(d)
}
path.delete()
}
}
3.調用模型預測
package ALSdemo
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
object alsLoadModelTest {
//屏蔽不必要的日志顯示在終端上
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
//給用戶推薦
val conf = new SparkConf().setMaster("local").setAppName("rem_test")
val sc = new SparkContext(conf)
val myModelPath = "E:\\Spark\\scala-data\\Model\\alsTest"
val model = MatrixFactorizationModel.load(sc, myModelPath)
val recommendProducts = model.recommendProducts(102, 12)
for (r <- recommendProducts) {
println(r.toString)
}
}
}
4.預測結果
user101:
Rating(101,1010,4.000591419102056)
Rating(101,1022,3.9969496458948193)
Rating(101,1001,3.9772784041229023)
Rating(101,1015,3.5501142515465673)
Rating(101,1016,3.4999375705609506)
Rating(101,1004,3.0070683414579378)
Rating(101,1006,2.64035448857031)
Rating(101,1021,2.5037825017384447)
Rating(101,1023,2.5037825017384447)
Rating(101,1002,2.4961448711069245)
user102:
Rating(102,1022,5.004920261743269)
Rating(102,1017,4.503333959561672)
Rating(102,1015,3.986380420809543)
Rating(102,1011,3.9743258175532787)
Rating(102,1013,3.5025929824248347)
Rating(102,1004,3.481621012016846)
Rating(102,1002,2.509660995203404)
Rating(102,1009,2.5018521424997786)
Rating(102,1006,1.9992398840722876)
Rating(102,1003,1.019633914552224)
Rating(102,1018,0.5646255853665232)
Rating(102,1016,0.49503960012882686)