in spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala [65:152]
/**
 * Adds a column listing, for every row, the rows whose geometry lies within a distance band
 * together with a weight for each of them.
 *
 * The dataframe is left-joined with itself on the distance condition, so rows without any
 * neighbor are kept. The result column is an array of structs with the fields `neighbor` (the
 * retained columns of the neighboring row) and `value` (its weight). Entries whose neighbor
 * geometry is null, such as the placeholder produced by an unmatched left join, are removed
 * from the array.
 *
 * Rows are identified by a SHA-256 hash of their JSON-serialised contents. Rows with identical
 * contents therefore share an identifier: they are grouped together and never reported as
 * neighbors of one another.
 *
 * @param dataframe
 *   rows to annotate
 * @param threshold
 *   largest distance (inclusive) at which two rows count as neighbors; must be >= 0
 * @param binary
 *   when true every neighbor gets the weight 1.0; when false the weight is the distance raised
 *   to the power `alpha`
 * @param alpha
 *   exponent used when `binary` is false; must be < 0 (checked even when `binary` is true)
 * @param includeZeroDistanceNeighbors
 *   when false, pairs at distance 0 are excluded from the join. NOTE(review): when true and
 *   `binary` is false, a zero distance is raised to a negative power, which presumably yields a
 *   non-finite weight - confirm this is intended
 * @param includeSelf
 *   when true, each row is appended to its own neighbor list with the weight `selfWeight`
 * @param selfWeight
 *   weight assigned to the row itself when `includeSelf` is true
 * @param geometry
 *   name of the geometry column; when null it is resolved from the schema by
 *   `getGeometryColumnName`
 * @param useSpheroid
 *   when true distances are computed with `ST_DistanceSpheroid`, otherwise with `ST_Distance`
 * @param savedAttributes
 *   columns to keep in each `neighbor` struct; the geometry column is appended if it is missing.
 *   When null, every column of the neighboring row is kept
 * @param resultName
 *   name of the added column
 * @return
 *   the columns of `dataframe` followed by the column `resultName`
 * @throws IllegalArgumentException
 *   if `threshold` is negative, `alpha` is not negative, or `geometry` is given but is not a
 *   column of `dataframe`
 */
def addDistanceBandColumn(
    dataframe: DataFrame,
    threshold: Double,
    binary: Boolean = true,
    alpha: Double = -1.0,
    includeZeroDistanceNeighbors: Boolean = false,
    includeSelf: Boolean = false,
    selfWeight: Double = 1.0,
    geometry: String = null,
    useSpheroid: Boolean = false,
    savedAttributes: Seq[String] = null,
    resultName: String = "weights"): DataFrame = {
  require(threshold >= 0, "Threshold must be greater than or equal to 0")
  require(alpha < 0, "Alpha must be less than 0")

  val geometryColumn =
    if (geometry == null) getGeometryColumnName(dataframe.schema)
    else {
      require(
        dataframe.schema.fields.exists(_.name == geometry),
        s"Geometry column $geometry not found in dataframe")
      geometry
    }

  // None means "keep every column"; otherwise the geometry column is always part of the list
  val neighborAttributes: Option[Seq[String]] =
    Option(savedAttributes).map { attributes =>
      if (attributes.contains(geometryColumn)) attributes else attributes :+ geometryColumn
    }

  val measure: (Column, Column) => Column =
    if (useSpheroid) ST_DistanceSpheroid else ST_Distance

  // Distance between the two sides of the self-join, shared by the join condition and the weight
  val pairDistance = measure(col(s"l.$geometryColumn"), col(s"r.$geometryColumn"))
  val withinBand = pairDistance <= threshold
  val joinCondition =
    if (includeZeroDistanceNeighbors) withinBand
    else withinBand && (pairDistance > 0)

  val identified = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256))
  val leftId = col(s"l.$ID_COLUMN")
  val rightId = col(s"r.$ID_COLUMN")

  // Columns of the right-hand row that are kept for each neighbor
  val neighborContents = neighborAttributes match {
    case Some(attributes) => struct(attributes.map(name => col(s"r.$name")): _*)
    case None => struct(col("r.*")).dropFields(ID_COLUMN)
  }

  val weightValue =
    if (binary) lit(1.0)
    else pow(pairDistance, alpha)

  // One-element array holding the row itself, shaped like the joined neighbor entries
  val selfEntries: Column =
    if (includeSelf) {
      val selfContents = neighborAttributes match {
        case Some(attributes) =>
          struct(attributes.map(name => first(s"left_contents.$name").alias(name)): _*)
        case None => first("left_contents").dropFields(ID_COLUMN)
      }
      array(struct(selfContents.alias("neighbor"), lit(selfWeight).alias("value")))
    } else array()

  identified
    .alias("l")
    .join(
      identified.alias("r"),
      // a row never joins with itself here; it is added back below when includeSelf is set
      joinCondition && (leftId =!= rightId),
      "left")
    .select(
      leftId,
      struct("l.*").alias("left_contents"),
      struct(neighborContents.alias("neighbor"), weightValue.alias("value")).alias("weight"))
    .groupBy(s"l.$ID_COLUMN")
    .agg(
      first("left_contents").alias("left_contents"),
      concat(collect_list(col("weight")), selfEntries).alias(resultName))
    .select("left_contents.*", resultName)
    .drop(ID_COLUMN)
    // unmatched left-join rows leave an entry with a null neighbor geometry; discard those
    .withColumn(
      resultName,
      filter(
        col(resultName),
        (entry: Column) => entry("neighbor")(geometryColumn).isNotNull))
}