in wayang-api/wayang-api-json/src/main/scala/builder/JsonPlanBuilder.scala [544:602]
/**
 * Expands a JSON-described k-means operator into a fixed-iteration clustering loop over `dataQuanta`.
 *
 * The four UDFs are compiled from the source strings carried in `operator.data`:
 *  - `distanceUdf(point, centroid)` must return a numeric distance (any `java.lang.Number`);
 *  - `sumUdf(a, b)` combines two points that were assigned to the same centroid;
 *  - `divideUdf(sum, count)` turns a summed point and its point count into the new centroid;
 *  - `initialCentroidsUdf(k)` must return an `Iterable` whose elements are `(centroid, id)` pairs.
 *
 * Each of the `operator.data.maxIterations` rounds broadcasts the current centroids, tags every
 * point with the id of its nearest centroid, adds up the points per id and divides by the count.
 *
 * @param operator   the parsed k-means operator holding `k`, `maxIterations` and the UDF sources
 * @param dataQuanta the points to cluster
 * @return the final `(centroid, id)` pairs, exposed as `DataQuanta[Any]`
 */
private def visit(operator: KMeansFromJson, dataQuanta: DataQuanta[Any]): DataQuanta[Any] = {
  val distanceUdf = SerializableLambda2.createLambda[Any, Any, Any](operator.data.distanceUdf)
  val sumUdf = SerializableLambda2.createLambda[Any, Any, Any](operator.data.sumUdf)
  val divideUdf = SerializableLambda2.createLambda[Any, Any, Any](operator.data.divideUdf)
  val initialCentroidsUdf = SerializableLambda.createLambda[Any, Any](operator.data.initialCentroidsUdf)

  /** Tags each point with the id of the closest broadcast centroid according to `distanceUdf`. */
  class SelectNearestCentroid extends ExtendedSerializableFunction[Any, Any] {
    // Filled in `open` from the broadcast registered under "centroids" (see `withBroadcast` below).
    var centroids: java.util.Collection[(Any, Int)] = _

    override def open(executionCtx: ExecutionContext): Unit = {
      centroids = executionCtx.getBroadcast[(Any, Int)]("centroids")
    }

    override def apply(point: Any): Any = {
      var minDistance = Double.PositiveInfinity
      // Stays -1 when there are no centroids (or no distance is < +Infinity).
      var nearestCentroidId: Int = -1
      for (centroid <- centroids.asScala) {
        val distance = toDistance(distanceUdf(point, centroid._1))
        if (distance < minDistance) {
          minDistance = distance
          nearestCentroidId = centroid._2
        }
      }
      TaggedPointCounter(point, nearestCentroidId)
    }

    /**
     * Converts the raw result of the user-supplied distance UDF into a `Double`.
     *
     * A plain `asInstanceOf[Double]` only accepts a boxed `java.lang.Double`. It would throw a
     * `ClassCastException` for an `Int`/`Long`/`Float` result and silently turn `null` into 0.0.
     * Any `java.lang.Number` is therefore accepted here.
     *
     * Kept as a member of this serializable class (not of the enclosing builder), so shipping the
     * UDF does not drag the outer object along.
     *
     * @throws IllegalArgumentException if the UDF returned something that is not a number
     */
    private def toDistance(value: Any): Double = value match {
      case number: java.lang.Number => number.doubleValue()
      case other =>
        throw new IllegalArgumentException(
          s"KMeans distanceUdf must return a numeric value, but returned: $other")
    }
  }

  /** A (partial) sum of `count` points that were all assigned to centroid `centroidId`. */
  case class TaggedPointCounter(point: Any, centroidId: Int, count: Int = 1) {
    def +(that: TaggedPointCounter): TaggedPointCounter =
      TaggedPointCounter(
        sumUdf(this.point, that.point),
        this.centroidId,
        this.count + that.count
      )

    /** The new centroid for `centroidId`: the summed point divided by the number of points. */
    def average: (Any, Int) = (divideUdf(point, count), centroidId)
  }

  val initialCentroids = planBuilder.loadCollection(initialCentroidsUdf(operator.data.k).asInstanceOf[Iterable[Any]])
  val finalCentroids = initialCentroids.map(_.asInstanceOf[(Any, Int)]).repeat(
    operator.data.maxIterations,
    { currentCentroids =>
      dataQuanta
        .mapJava(
          new SelectNearestCentroid
          // udfLoad = LoadProfileEstimators.createFromSpecification("wayang.apps.kmeans.udfs.select-centroid.load", configuration)
        )
        // The name must match the one read in SelectNearestCentroid.open.
        .withBroadcast(currentCentroids, "centroids")
        // One reduce group per centroid id. NOTE(review): a centroid that attracts no points produces
        // no group and presumably drops out of later iterations — confirm this is intended.
        .reduceByKey(_.asInstanceOf[TaggedPointCounter].centroidId, _.asInstanceOf[TaggedPointCounter] + _.asInstanceOf[TaggedPointCounter])
        .withCardinalityEstimator(operator.data.k)
        .map(_.asInstanceOf[TaggedPointCounter].average)
    }
  )
  finalCentroids.asInstanceOf[DataQuanta[Any]]
}