private def visit()

in wayang-api/wayang-api-json/src/main/scala/builder/JsonPlanBuilder.scala [544:602]


  /**
   * Builds the iterative k-means portion of the plan for a [[KMeansFromJson]] operator.
   *
   * The starting centroids are produced by applying the operator's `initialCentroidsUdf`
   * to `k` and loading the result as a collection. The following steps are then repeated
   * `maxIterations` times:
   *
   *  1. every element of `dataQuanta` is tagged with the id of the centroid for which
   *     `distanceUdf` yields the smallest distance (the current centroids are supplied
   *     through the `"centroids"` broadcast);
   *  1. tagged points sharing a centroid id are combined with `sumUdf` while their
   *     counts are added up;
   *  1. each combined point is passed to `divideUdf` together with its count, yielding
   *     the `(centroid, id)` pairs for the next round.
   *
   * @param operator   the k-means operator description; its `data` supplies the UDFs,
   *                   `k` and `maxIterations`
   * @param dataQuanta the points to be clustered
   * @return the centroids produced by the last iteration, as `(centroid, id)` pairs
   *         behind a `DataQuanta[Any]`
   */
  private def visit(operator: KMeansFromJson, dataQuanta: DataQuanta[Any]): DataQuanta[Any] = {

    val params = operator.data
    val k = params.k
    val maxIterations = params.maxIterations

    // Created up front so that the local classes below can close over them.
    val distanceUdf = SerializableLambda2.createLambda[Any, Any, Any](params.distanceUdf)
    val sumUdf = SerializableLambda2.createLambda[Any, Any, Any](params.sumUdf)
    val divideUdf = SerializableLambda2.createLambda[Any, Any, Any](params.divideUdf)
    val initialCentroidsUdf = SerializableLambda.createLambda[Any, Any](params.initialCentroidsUdf)

    /** A point (or a running sum of points) assigned to one centroid, plus how many points it covers. */
    case class TaggedPointCounter(point: Any, centroidId: Int, count: Int = 1) {

      /** Combines two counters; the resulting counter keeps the centroid id of the receiver. */
      def merge(other: TaggedPointCounter): TaggedPointCounter = {
        val summedPoint = sumUdf(point, other.point)
        TaggedPointCounter(summedPoint, centroidId, count + other.count)
      }

      /** Applies `divideUdf` to the accumulated point and its count and pairs the result with the centroid id. */
      def average: (Any, Int) = {
        val mean = divideUdf(point, count)
        (mean, centroidId)
      }
    }

    /** Tags each incoming point with the id of the closest broadcast centroid. */
    class SelectNearestCentroid extends ExtendedSerializableFunction[Any, Any] {

      // Filled in by `open` from the "centroids" broadcast.
      var centroids: java.util.Collection[(Any, Int)] = _

      override def open(executionCtx: ExecutionContext): Unit = {
        centroids = executionCtx.getBroadcast[(Any, Int)]("centroids")
      }

      override def apply(point: Any): Any = {
        // The id stays at -1 when no centroid yields a distance below +Infinity.
        var bestDistance = Double.PositiveInfinity
        var bestCentroidId = -1
        val cursor = centroids.iterator()
        while (cursor.hasNext) {
          val candidate = cursor.next()
          val candidateDistance = distanceUdf(point, candidate._1).asInstanceOf[Double]
          // Strict comparison: on equal distances the centroid seen first is kept.
          if (candidateDistance < bestDistance) {
            bestDistance = candidateDistance
            bestCentroidId = candidate._2
          }
        }
        TaggedPointCounter(point, bestCentroidId)
      }
    }

    val seedCentroids = planBuilder.loadCollection(initialCentroidsUdf(k).asInstanceOf[Iterable[Any]])

    val finalCentroids = seedCentroids
      .map(seed => seed.asInstanceOf[(Any, Int)])
      .repeat(
        maxIterations,
        { currentCentroids =>
          // Assignment step.
          // A load profile estimator could be supplied to mapJava as well:
          // udfLoad = LoadProfileEstimators.createFromSpecification("wayang.apps.kmeans.udfs.select-centroid.load", configuration)
          val taggedPoints = dataQuanta
            .mapJava(new SelectNearestCentroid)
            .withBroadcast(currentCentroids, "centroids")

          // Aggregation step: one combined counter per centroid id.
          val perCentroid = taggedPoints
            .reduceByKey(
              tagged => tagged.asInstanceOf[TaggedPointCounter].centroidId,
              (left, right) =>
                left.asInstanceOf[TaggedPointCounter].merge(right.asInstanceOf[TaggedPointCounter])
            )
            .withCardinalityEstimator(k)

          // Update step: turn every counter into the next (centroid, id) pair.
          perCentroid.map(counter => counter.asInstanceOf[TaggedPointCounter].average)
        }
      )

    finalCentroids.asInstanceOf[DataQuanta[Any]]
  }