in spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala [60:144]
/**
 * Runs DBSCAN clustering on the rows of `dataframe` using their geometry column, returning the
 * input rows annotated with a core-point flag and a cluster id.
 *
 * Implementation outline (all distributed Spark operations):
 *   1. Assign every row a deterministic synthetic id (SHA-256 of the row's JSON).
 *   2. Self-join within `epsilon` to count each point's neighbors; a point is "core" when its
 *      neighbor count >= `minPts`. NOTE: the join predicate `distance <= epsilon` also matches a
 *      row with itself (distance 0), so each point counts itself among its neighbors.
 *   3. Build a graph over core points only and label clusters via GraphFrames connected
 *      components.
 *   4. Attach each border (non-core) point to the smallest component id among its core
 *      neighbors; optionally append outliers with sentinel cluster id -1.
 *
 * @param dataframe         input DataFrame containing a geometry column
 * @param epsilon           neighborhood radius; units follow the chosen distance function
 *                          (presumably meters when `useSpheroid`, CRS units otherwise — see
 *                          ST_Distance / ST_DistanceSpheroid docs)
 * @param minPts            minimum neighbor count (self-inclusive, per the note above) for a
 *                          point to be a core point
 * @param geometry          name of the geometry column; when null it is inferred from the schema
 * @param includeOutliers   when true, noise points are kept in the output with cluster id -1
 * @param useSpheroid       when true, distances are computed with ST_DistanceSpheroid instead of
 *                          ST_Distance
 * @param isCoreColumnName  name of the boolean output column marking core points
 * @param clusterColumnName name of the output column holding the cluster id
 * @return the input rows (original columns preserved, including any pre-existing `id` column)
 *         plus the core-flag and cluster-id columns
 */
def dbscan(
    dataframe: DataFrame,
    epsilon: Double,
    minPts: Int,
    geometry: String = null,
    includeOutliers: Boolean = true,
    useSpheroid: Boolean = false,
    isCoreColumnName: String = "isCore",
    clusterColumnName: String = "cluster"): DataFrame = {

  // Resolve the geometry column: explicit name wins, otherwise infer from the schema.
  val geometryCol = geometry match {
    case null => getGeometryColumnName(dataframe.schema)
    case _ => geometry
  }
  validateInputs(dataframe, epsilon, minPts, geometryCol)

  // Spheroid-aware vs. standard distance; both are Sedona SQL functions of two geometry columns.
  val distanceFunction: (Column, Column) => Column =
    if (useSpheroid) ST_DistanceSpheroid else ST_Distance

  // The algorithm needs a unique per-row key named "id". If the caller already has an "id"
  // column, stash it under ID_COLUMN so it can be restored at the end, then synthesize a
  // deterministic id by hashing the full row as JSON. NOTE(review): duplicate rows hash to the
  // same id and would collapse in the groupBy below — assumes rows are distinct.
  val hasIdColumn = dataframe.columns.contains("id")
  val idDataframe = if (hasIdColumn) {
    dataframe
      .withColumnRenamed("id", ID_COLUMN)
      .withColumn("id", sha2(to_json(struct("*")), 256))
  } else {
    dataframe.withColumn("id", sha2(to_json(struct("*")), 256))
  }

  // Epsilon self-join: pair every point with every point within epsilon (including itself),
  // then count and collect neighbors per left point. A point is core when its (self-inclusive)
  // neighbor count reaches minPts. checkpoint() truncates the lineage so the expensive join is
  // not recomputed by the several downstream branches that reuse this DataFrame.
  val isCorePointsDF = idDataframe
    .alias("left")
    .join(
      idDataframe.alias("right"),
      distanceFunction(col(s"left.$geometryCol"), col(s"right.$geometryCol")) <= epsilon)
    .groupBy(col(s"left.id"))
    .agg(
      first(struct("left.*")).alias("leftContents"),
      count(col(s"right.id")).alias("neighbors_count"),
      collect_list(col(s"right.id")).alias("neighbors"))
    .withColumn(isCoreColumnName, col("neighbors_count") >= lit(minPts))
    .select("leftContents.*", "neighbors", isCoreColumnName)
    .checkpoint()

  val corePointsDF = isCorePointsDF.filter(col(isCoreColumnName))
  val borderPointsDF = isCorePointsDF.filter(!col(isCoreColumnName))

  // Edges of the core-point graph: src -> dst only where dst is itself a core point (the inner
  // join against corePointsDF drops edges to non-core neighbors).
  val coreEdgesDf = corePointsDF
    .select(col("id").alias("src"), explode(col("neighbors")).alias("dst"))
    .alias("left")
    .join(corePointsDF.alias("right"), col("left.dst") === col(s"right.id"))
    .select(col("left.src"), col(s"right.id").alias("dst"))

  // Connected components over core points = the DBSCAN clusters; adds a "component" column.
  val connectedComponentsDF = GraphFrame(corePointsDF, coreEdgesDf).connectedComponents.run

  // Border points join any cluster containing a core neighbor; ties are broken deterministically
  // by taking the minimum component id.
  val borderComponentsDF = borderPointsDF
    .select(struct("*").alias("leftContent"), explode(col("neighbors")).alias("neighbor"))
    .join(connectedComponentsDF.alias("right"), col("neighbor") === col(s"right.id"))
    .groupBy(col("leftContent.id"))
    .agg(
      first(col("leftContent")).alias("leftContent"),
      min(col(s"right.component")).alias("component"))
    .select("leftContent.*", "component")

  val clusteredPointsDf = borderComponentsDF.union(connectedComponentsDF)

  // Outliers: rows assigned to no cluster. Sentinel component -1, never core, and an empty
  // neighbors array cast to array<string> so the schema lines up with collect_list of string ids
  // for the unionByName below.
  val outliersDf = idDataframe
    .join(clusteredPointsDf, Seq("id"), "left_anti")
    .withColumn(isCoreColumnName, lit(false))
    .withColumn("component", lit(-1))
    .withColumn("neighbors", array().cast("array<string>"))

  val completedDf = (
    if (includeOutliers) clusteredPointsDf.unionByName(outliersDf)
    else clusteredPointsDf
  ).withColumnRenamed("component", clusterColumnName)

  // Drop the internal bookkeeping columns and restore the caller's original "id" column, if any.
  val returnDf = if (hasIdColumn) {
    completedDf.drop("neighbors", "id").withColumnRenamed(ID_COLUMN, "id")
  } else {
    completedDf.drop("neighbors", "id")
  }

  returnDf
}