in core/src/main/scala/com/microsoft/azure/synapse/ml/recommendation/RankingTrainValidationSplit.scala [160:240]
/**
  * Keeps only ratings whose item has at least `getMinRatingsI` user ratings.
  *
  * Counts how many users rated each item, retains the items that meet the
  * threshold, then joins back to the original dataset to recover full rows.
  * The result is cached — callers may traverse it more than once.
  */
private def filterByUserRatingCount(dataset: Dataset[_]): DataFrame = {
  val ratingsPerItem = dataset
    .groupBy(getItemCol)
    .agg(col(getItemCol), count(col(getUserCol)).alias("ncustomers"))
  val itemsWithEnoughRatings = ratingsPerItem.where(col("ncustomers") >= getMinRatingsI)
  itemsWithEnoughRatings
    .join(dataset, getItemCol)
    .drop("ncustomers")
    .cache()
}
/**
  * Applies both sparsity filters to the ratings dataset: drops items with too
  * few user ratings and (via `filterByItemCount`) users with too few item
  * ratings, joining the two filtered views on the user column.
  */
def filterRatings(dataset: Dataset[_]): DataFrame = {
  val itemFiltered = filterByUserRatingCount(dataset)
  val userFiltered = filterByItemCount(dataset)
  itemFiltered.join(userFiltered, $(userCol))
}
/**
  * Splits the dataset per user into train/test partitions.
  *
  * For every user, the rated items are collected into a list, (optionally)
  * shuffled, and the first `trainRatio` fraction is assigned to the training
  * split while the remainder goes to the test split.
  *
  * When a rating column is present, (item, rating) pairs are kept together so
  * the splits retain ratings; otherwise only item ids are split. NOTE(review):
  * only the rated branch shuffles before slicing — confirm whether the unrated
  * branch should shuffle as well.
  *
  * @param dataset input ratings with at least user and item columns
  * @return Array(train, test) — train at index 0, test at index 1
  */
def splitDF(dataset: DataFrame): Array[DataFrame] = {
  val shuffleFlag = true
  val shuffleBC = dataset.sparkSession.sparkContext.broadcast(shuffleFlag)

  if (dataset.columns.contains(getRatingCol)) {
    // Keep item id and rating together while splitting.
    val wrapColumn = udf((itemId: Double, rating: Double) => Array(itemId, rating))
    val sliceudf = udf(
      (r: Seq[Array[Double]]) => r.slice(0, math.round(r.length * $(trainRatio)).toInt))
    val shuffle = udf((r: Seq[Array[Double]]) =>
      if (shuffleBC.value) Random.shuffle(r)
      else r
    )
    val dropudf = udf((r: Seq[Array[Double]]) => r.drop(math.round(r.length * $(trainRatio)).toInt))
    val testds = dataset
      .withColumn("itemIDRating", wrapColumn(col(getItemCol), col(getRatingCol)))
      .groupBy(col(getUserCol))
      .agg(collect_list(col("itemIDRating")))
      .withColumn("shuffle", shuffle(col("collect_list(itemIDRating)")))
      .withColumn("train", sliceudf(col("shuffle")))
      .withColumn("test", dropudf(col("shuffle")))
      .drop(col("collect_list(itemIDRating)")).drop(col("shuffle"))
      // Cache is required for correctness, not just speed: the shuffle UDF is
      // nondeterministic (Random.shuffle), and testds is consumed twice below.
      // Without caching, the train and test projections could each recompute
      // the shuffle with a different permutation, yielding overlapping or
      // missing items between the two splits. (Was commented out; restored to
      // match the unrated branch.)
      .cache()
    val train = testds
      .select(getUserCol, "train")
      .withColumn("itemIdRating", explode(col("train")))
      .drop("train")
      .withColumn(getItemCol, col("itemIdRating").getItem(0))
      .withColumn(getRatingCol, col("itemIdRating").getItem(1))
      .drop("itemIdRating")
    val test = testds
      .select(getUserCol, "test")
      .withColumn("itemIdRating", explode(col("test")))
      .drop("test")
      .withColumn(getItemCol, col("itemIdRating").getItem(0))
      .withColumn(getRatingCol, col("itemIdRating").getItem(1))
      .drop("itemIdRating")
    Array(train, test)
  }
  else {
    val sliceudf = udf(
      (r: Seq[Double]) => r.slice(0, math.round(r.length * $(trainRatio)).toInt))
    val dropudf = udf((r: Seq[Double]) => r.drop(math.round(r.length * $(trainRatio)).toInt))
    val testDS = dataset
      .groupBy(col(getUserCol))
      .agg(collect_list(col(getItemCol)).alias("shuffle"))
      .withColumn("train", sliceudf(col("shuffle")))
      .withColumn("test", dropudf(col("shuffle")))
      // The collect_list column was aliased to "shuffle" above, so only that
      // column needs dropping. The original also dropped
      // s"collect_list($getItemCol" — a malformed interpolation (missing ')')
      // naming a column that never exists — which was a dead no-op; removed.
      .drop(col("shuffle"))
      .cache()
    val train = testDS
      .select(getUserCol, "train")
      .withColumn(getItemCol, explode(col("train")))
      .drop("train")
    val test = testDS
      .select(getUserCol, "test")
      .withColumn(getItemCol, explode(col("test")))
      .drop("test")
    Array(train, test)
  }
}