// Excerpt from core/src/main/scala/com/microsoft/azure/synapse/ml/recommendation/RankingTrainValidationSplit.scala
// (original lines 160-240): filterByUserRatingCount, filterRatings, and splitDF helpers.

  /**
    * Keeps only the rows whose item has at least `getMinRatingsI` ratings.
    *
    * NOTE(review): despite the name, this filters *items* by how many ratings
    * they received (rows per item), joining the qualifying items back onto the
    * input; the per-item count column "ncustomers" is dropped afterwards.
    * The group column also appears inside `agg`, mirroring the original code.
    */
  private def filterByUserRatingCount(dataset: Dataset[_]): DataFrame = {
    val ratingsPerItem = dataset
      .groupBy(getItemCol)
      .agg(col(getItemCol), count(col(getUserCol)).alias("ncustomers"))
    val popularItems = ratingsPerItem.where(col("ncustomers") >= getMinRatingsI)
    popularItems
      .join(dataset, getItemCol)
      .drop("ncustomers")
      .cache()
  }

  /**
    * Applies both sparsity filters — minimum ratings per item and per user —
    * and joins the two filtered views on the user column.
    */
  def filterRatings(dataset: Dataset[_]): DataFrame = {
    val itemsWithEnoughRatings = filterByUserRatingCount(dataset)
    val usersWithEnoughRatings = filterByItemCount(dataset)
    itemsWithEnoughRatings.join(usersWithEnoughRatings, $(userCol))
  }

  /**
    * Splits the dataset into train/test partitions per user: each user's items
    * are collected into a list, the first `trainRatio` fraction goes to train
    * and the remainder to test.
    *
    * @param dataset input ratings; must contain the user and item columns, and
    *                optionally the rating column.
    * @return Array(train, test) with the same user/item(/rating) schema.
    */
  def splitDF(dataset: DataFrame): Array[DataFrame] = {
    // NOTE(review): only the rating-column branch shuffles each user's list;
    // the else branch splits in collect_list order — confirm this asymmetry
    // is intended.
    val shuffleFlag = true
    val shuffleBC = dataset.sparkSession.sparkContext.broadcast(shuffleFlag)

    if (dataset.columns.contains(getRatingCol)) {
      // Pack (itemId, rating) pairs so one collect_list keeps them aligned.
      val wrapColumn = udf((itemId: Double, rating: Double) => Array(itemId, rating))

      // First trainRatio fraction of each user's list.
      val sliceudf = udf(
        (r: Seq[Array[Double]]) => r.slice(0, math.round(r.length * $(trainRatio)).toInt))

      val shuffle = udf((r: Seq[Array[Double]]) =>
        if (shuffleBC.value) Random.shuffle(r)
        else r
      )
      // Complement of sliceudf: the remaining (1 - trainRatio) fraction.
      val dropudf = udf((r: Seq[Array[Double]]) => r.drop(math.round(r.length * $(trainRatio)).toInt))

      val testds = dataset
        .withColumn("itemIDRating", wrapColumn(col(getItemCol), col(getRatingCol)))
        .groupBy(col(getUserCol))
        .agg(collect_list(col("itemIDRating")))
        .withColumn("shuffle", shuffle(col("collect_list(itemIDRating)")))
        .withColumn("train", sliceudf(col("shuffle")))
        .withColumn("test", dropudf(col("shuffle")))
        .drop(col("collect_list(itemIDRating)")).drop(col("shuffle"))

      Array(explodeItemRatings(testds, "train"), explodeItemRatings(testds, "test"))
    }
    else {
      // No rating column: split the bare item-id lists the same way.
      val sliceudf = udf(
        (r: Seq[Double]) => r.slice(0, math.round(r.length * $(trainRatio)).toInt))
      val dropudf = udf((r: Seq[Double]) => r.drop(math.round(r.length * $(trainRatio)).toInt))

      val testDS = dataset
        .groupBy(col(getUserCol))
        .agg(collect_list(col(getItemCol)).alias("shuffle"))
        .withColumn("train", sliceudf(col("shuffle")))
        .withColumn("test", dropudf(col("shuffle")))
        // BUG FIX: the original dropped col(s"collect_list($getItemCol") — the
        // interpolated string was missing its closing ')', and no such column
        // exists anyway because the aggregate above is aliased to "shuffle".
        .drop(col("shuffle"))
        .cache()

      val train = testDS
        .select(getUserCol, "train")
        .withColumn(getItemCol, explode(col("train")))
        .drop("train")

      val test = testDS
        .select(getUserCol, "test")
        .withColumn(getItemCol, explode(col("test")))
        .drop("test")

      Array(train, test)
    }
  }

  /** Explodes the packed [itemId, rating] arrays held in `column` back into
    * one row per rating, restoring separate item and rating columns. */
  private def explodeItemRatings(ds: DataFrame, column: String): DataFrame = ds
    .select(getUserCol, column)
    .withColumn("itemIdRating", explode(col(column)))
    .drop(column)
    .withColumn(getItemCol, col("itemIdRating").getItem(0))
    .withColumn(getRatingCol, col("itemIdRating").getItem(1))
    .drop("itemIdRating")