in spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala [103:166]
/**
 * Translates a Catalyst filter predicate into a [[GeoParquetSpatialFilter]] for GeoParquet
 * spatial filter pushdown.
 *
 * The translation never produces a filter stricter than `predicate`:
 *   - in a conjunction, an untranslatable side is dropped (this only weakens the filter);
 *   - in a disjunction, both sides must translate, otherwise nothing is pushed;
 *   - negations are never pushed.
 *
 * Leaf predicates are pushed when one operand is a pushable column and the other is a geometry
 * literal. Literal values are matched as `Array[Byte]` (what `GeometryUDT.deserialize` is given),
 * so a NULL or non-binary literal disables pushdown for that leaf instead of being handed to the
 * deserializer. Upper bounds on `ST_Distance(col, lit)` are pushed in either operand order
 * (`dist < d`, `dist <= d`, `d > dist`, `d >= dist`) when `d` is a non-null double literal.
 *
 * @param predicate
 *   the filter expression to translate
 * @param pushableColumn
 *   extractor resolving an expression to the name of a pushable column
 * @return
 *   a spatial filter implied by `predicate`, or `None` if no part of it can be pushed down
 */
private def translateToGeoParquetSpatialFilter(
    predicate: Expression,
    pushableColumn: PushableColumnBase): Option[GeoParquetSpatialFilter] = {
  predicate match {
    case And(left, right) =>
      // Dropping an untranslatable conjunct only weakens the filter, so keep whichever
      // side(s) translated.
      val spatialFilterLeft = translateToGeoParquetSpatialFilter(left, pushableColumn)
      val spatialFilterRight = translateToGeoParquetSpatialFilter(right, pushableColumn)
      (spatialFilterLeft, spatialFilterRight) match {
        case (Some(l), Some(r)) => Some(AndFilter(l, r))
        case (l, r) => l.orElse(r)
      }

    case Or(left, right) =>
      // Dropping a disjunct would make the filter stricter than the predicate (and skip rows
      // that match the other disjunct), so both sides must translate.
      for {
        spatialFilterLeft <- translateToGeoParquetSpatialFilter(left, pushableColumn)
        spatialFilterRight <- translateToGeoParquetSpatialFilter(right, pushableColumn)
      } yield OrFilter(spatialFilterLeft, spatialFilterRight)

    case Not(_) =>
      // None of AndFilter / OrFilter / LeafFilter can express a negation.
      None

    // The column must cover the literal:
    // ST_Contains/ST_Covers(col, lit) and ST_Within/ST_CoveredBy(lit, col).
    case ST_Contains(Seq(pushableColumn(name), Literal(bytes: Array[Byte], _))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(bytes)))
    case ST_Covers(Seq(pushableColumn(name), Literal(bytes: Array[Byte], _))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(bytes)))
    case ST_Within(Seq(Literal(bytes: Array[Byte], _), pushableColumn(name))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(bytes)))
    case ST_CoveredBy(Seq(Literal(bytes: Array[Byte], _), pushableColumn(name))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(bytes)))

    // The column lies inside the literal, so it must at least intersect it:
    // ST_Contains/ST_Covers(lit, col) and ST_Within/ST_CoveredBy(col, lit).
    case ST_Contains(Seq(Literal(bytes: Array[Byte], _), pushableColumn(name))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(bytes)))
    case ST_Covers(Seq(Literal(bytes: Array[Byte], _), pushableColumn(name))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(bytes)))
    case ST_Within(Seq(pushableColumn(name), Literal(bytes: Array[Byte], _))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(bytes)))
    case ST_CoveredBy(Seq(pushableColumn(name), Literal(bytes: Array[Byte], _))) =>
      Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(bytes)))

    // Symmetric predicates: operand order does not matter.
    case ST_Equals(_) | ST_OrderingEquals(_) =>
      // Equal geometries cover each other.
      resolveNameAndLiteral(predicate.children, pushableColumn).collect {
        case (name, bytes: Array[Byte]) =>
          LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(bytes))
      }
    case ST_Intersects(_) | ST_Crosses(_) | ST_Overlaps(_) | ST_Touches(_) =>
      // Each of these relationships implies that the geometries intersect.
      resolveNameAndLiteral(predicate.children, pushableColumn).collect {
        case (name, bytes: Array[Byte]) =>
          LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(bytes))
      }

    case DistanceUpperBound(distArgs, d) =>
      resolveNameAndLiteral(distArgs, pushableColumn).collect {
        case (name, bytes: Array[Byte]) =>
          distanceFilter(name, GeometryUDT.deserialize(bytes), d)
      }

    case _ => None
  }
}

/**
 * Extractor for predicates that bound `ST_Distance(args)` from above by a non-null double
 * literal `d`, with the distance on either side of the comparison. Yields `(args, d)`.
 * Strict and non-strict bounds are not distinguished: both push the same distance filter.
 */
private object DistanceUpperBound {
  // Local import keeps this helper independent of the file-level import list.
  import org.apache.spark.sql.catalyst.expressions.{GreaterThan, GreaterThanOrEqual}

  def unapply(expr: Expression): Option[(Seq[Expression], Double)] = expr match {
    case LessThan(ST_Distance(args), Literal(d: Double, DoubleType)) => Some((args, d))
    case LessThanOrEqual(ST_Distance(args), Literal(d: Double, DoubleType)) => Some((args, d))
    // `d > dist` / `d >= dist` are the same bounds written with the literal first.
    case GreaterThan(Literal(d: Double, DoubleType), ST_Distance(args)) => Some((args, d))
    case GreaterThanOrEqual(Literal(d: Double, DoubleType), ST_Distance(args)) =>
      Some((args, d))
    case _ => None
  }
}