def addDistanceBandColumn()

in spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala [65:152]


  /**
   * Adds a column of distance-band neighbor weights to `dataframe`.
   *
   * Every row is left-outer self-joined against every other row whose geometry lies within
   * `threshold` of its own geometry. The matches are collected into an array column named
   * `resultName`. Each array element is a struct `(neighbor, value)`:
   *   - `neighbor` holds either all columns of the neighboring row (when `savedAttributes` is
   *     null) or only `savedAttributes` plus the geometry column;
   *   - `value` is the weight (1.0 when `binary`, otherwise `distance ^ alpha`).
   * Rows with no neighbor in the band get an empty array (plus the self entry when
   * `includeSelf` is set).
   *
   * NOTE: row identity is a SHA-256 hash of the row's JSON. Rows identical in every column
   * therefore share an ID: they are not neighbors of each other and are collapsed into a
   * single output row.
   *
   * @param dataframe input DataFrame containing a geometry column
   * @param threshold maximum (inclusive) distance for two rows to be neighbors; must be >= 0
   * @param binary if true every neighbor gets weight 1.0, otherwise `pow(distance, alpha)`
   * @param alpha inverse-distance exponent; must be < 0 (validated even when `binary` is true)
   * @param includeZeroDistanceNeighbors if true, rows at distance 0 also count as neighbors;
   *        with `binary = false` their weight is `pow(0, alpha)`, i.e. +Infinity
   * @param includeSelf if true, appends an entry for the row itself with weight `selfWeight`
   * @param selfWeight weight used for the self entry when `includeSelf` is true
   * @param geometry name of the geometry column; if null it is looked up from the schema
   * @param useSpheroid use `ST_DistanceSpheroid` instead of `ST_Distance`
   * @param savedAttributes neighbor columns to keep; if null, all neighbor columns are kept
   * @param resultName name of the output weights column
   * @return the input columns plus the `resultName` weights column
   * @throws IllegalArgumentException if `threshold` < 0, `alpha` >= 0, or `geometry` is not a
   *         column of `dataframe`
   */
  def addDistanceBandColumn(
      dataframe: DataFrame,
      threshold: Double,
      binary: Boolean = true,
      alpha: Double = -1.0,
      includeZeroDistanceNeighbors: Boolean = false,
      includeSelf: Boolean = false,
      selfWeight: Double = 1.0,
      geometry: String = null,
      useSpheroid: Boolean = false,
      savedAttributes: Seq[String] = null,
      resultName: String = "weights"): DataFrame = {

    require(threshold >= 0, "Threshold must be greater than or equal to 0")
    require(alpha < 0, "Alpha must be less than 0")

    val geometryColumn = geometry match {
      case null => getGeometryColumnName(dataframe.schema)
      case _ =>
        require(
          dataframe.schema.fields.exists(_.name == geometry),
          s"Geometry column $geometry not found in dataframe")
        geometry
    }

    // The geometry column is a top-level field name. An explicit `geometry` is checked against
    // top-level fields above; getGeometryColumnName presumably returns one as well (confirm).
    // Backtick-quote it whenever it is spliced into a parsed column path, so that a name such
    // as "geom.wkb" is not split on the dot. Embedded backticks are escaped by doubling them.
    val quotedGeometry = s"`${geometryColumn.replace("`", "``")}`"

    // Saved attributes are user-supplied (possibly nested) paths and stay unquoted; only the
    // geometry entry, which is known to be a top-level name, is quoted.
    def attributePath(name: String): String =
      if (name == geometryColumn) quotedGeometry else name

    // Always include the geometry column in the saved attributes
    val savedAttributesWithGeom =
      if (savedAttributes == null) null
      else if (!savedAttributes.contains(geometryColumn)) savedAttributes :+ geometryColumn
      else savedAttributes

    val distanceFunction: (Column, Column) => Column =
      if (useSpheroid) ST_DistanceSpheroid else ST_Distance

    // One distance expression, shared by the join condition and the inverse-distance weight.
    val distance = distanceFunction(col(s"l.$quotedGeometry"), col(s"r.$quotedGeometry"))

    val withinBand = distance <= threshold
    val joinCondition =
      if (includeZeroDistanceNeighbors) withinBand else withinBand && distance > 0

    // Deterministic per-row ID: excludes self-matches in the join and groups the join output
    // back into one row per input row.
    val formattedDataFrame = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256))

    val neighborColumn: Column = savedAttributesWithGeom match {
      case null => struct(col("r.*")).dropFields(ID_COLUMN)
      case attrs => struct(attrs.map(c => col(s"r.${attributePath(c)}")): _*)
    }

    val weightValue: Column = if (binary) lit(1.0) else pow(distance, alpha)

    // The join excludes self-matches, so the self entry (if requested) is appended after
    // aggregation, built from the left row's own contents.
    val selfEntries: Column =
      if (includeSelf) {
        val selfNeighbor = savedAttributesWithGeom match {
          case null => first("left_contents").dropFields(ID_COLUMN)
          case attrs =>
            struct(attrs.map(c => first(s"left_contents.${attributePath(c)}").alias(c)): _*)
        }
        array(struct(selfNeighbor.alias("neighbor"), lit(selfWeight).alias("value")))
      } else array()

    formattedDataFrame
      .alias("l")
      .join(
        formattedDataFrame.alias("r"),
        joinCondition && col(s"l.$ID_COLUMN") =!= col(s"r.$ID_COLUMN"),
        "left")
      .select(
        col(s"l.$ID_COLUMN"),
        struct("l.*").alias("left_contents"),
        struct(neighborColumn.alias("neighbor"), weightValue.alias("value")).alias("weight"))
      .groupBy(s"l.$ID_COLUMN")
      .agg(
        first("left_contents").alias("left_contents"),
        concat(collect_list(col("weight")), selfEntries).alias(resultName))
      .select("left_contents.*", resultName)
      .drop(ID_COLUMN)
      // Left rows without any match produce one entry whose neighbor geometry is null; drop it.
      .withColumn(resultName, filter(col(resultName), _("neighbor")(geometryColumn).isNotNull))
  }