def spark_matchup_driver()

in analysis/webservice/algorithms_spark/MatchupDoms.py [0:0]


def spark_matchup_driver(tile_ids, bounding_wkt, primary_ds_name, secondary_ds_names, parameter, depth_min, depth_max,
                         time_tolerance, radius_tolerance, platforms, match_once, tile_service_factory, sc=None):
    """
    Run the matchup as a Spark job and collect the result on the driver.

    The tile ids are parallelized into an RDD, every partition is handed to
    ``match_satellite_to_insitu`` together with the broadcast search parameters, and the
    (primary, matchup) pairs it produces are filtered by time difference and grouped by
    primary point.

    :param tile_ids: list of tile ids to distribute across the cluster
    :param bounding_wkt: broadcast to the workers as ``bounding_wkt_b`` (presumably a WKT search area; it is
        interpreted by ``match_satellite_to_insitu``)
    :param primary_ds_name: broadcast to the workers as ``primary_b``
    :param secondary_ds_names: broadcast to the workers as ``secondary_b``
    :param parameter: broadcast to the workers as ``parameter_b``
    :param depth_min: converted to ``float`` before broadcasting; ``None`` is passed through unchanged
    :param depth_max: converted to ``float`` before broadcasting; ``None`` is passed through unchanged
    :param time_tolerance: maximum allowed absolute difference between the ``iso_time_to_epoch`` values of the
        primary and the matchup ``time`` attributes (presumably seconds -- confirm against ``iso_time_to_epoch``)
    :param radius_tolerance: converted to ``float`` before broadcasting
    :param platforms: broadcast to the workers as ``platforms_b``
    :param match_once: when truthy, keep only the matchup with the smallest WGS84 geodesic distance to each primary
    :param tile_service_factory: passed through unchanged to ``match_satellite_to_insitu``
    :param sc: the SparkContext used to broadcast and parallelize. Required, despite the ``None`` default.
    :return: dict mapping each primary point to a list of its matchups (a one-element list when ``match_once``
        is truthy)
    :raises ValueError: if ``sc`` is ``None``
    """
    from functools import partial

    # Fail early with a clear message, before taking the lock, instead of an AttributeError on ``None.broadcast``.
    if sc is None:
        raise ValueError("spark_matchup_driver requires a SparkContext; 'sc' must not be None")

    # NOTE(review): DRIVER_LOCK presumably serializes driver-side setup between concurrent requests -- confirm.
    with DRIVER_LOCK:
        # Broadcast parameters
        primary_b = sc.broadcast(primary_ds_name)
        secondary_b = sc.broadcast(secondary_ds_names)
        depth_min_b = sc.broadcast(float(depth_min) if depth_min is not None else None)
        depth_max_b = sc.broadcast(float(depth_max) if depth_max is not None else None)
        tt_b = sc.broadcast(time_tolerance)
        rt_b = sc.broadcast(float(radius_tolerance))
        platforms_b = sc.broadcast(platforms)
        bounding_wkt_b = sc.broadcast(bounding_wkt)
        parameter_b = sc.broadcast(parameter)

        # Parallelize list of tile ids
        rdd = sc.parallelize(tile_ids, determine_parllelism(len(tile_ids)))

    def within_time_tolerance(p_m_tuple):
        # p_m_tuple is a (primary, matchup) pair; keep it only if the two timestamps are close enough.
        time_delta = iso_time_to_epoch(p_m_tuple[0].time) - iso_time_to_epoch(p_m_tuple[1].time)
        return abs(time_delta) <= time_tolerance

    # Map Partitions ( list(tile_id) )
    rdd_filtered = rdd.mapPartitions(
        partial(
            match_satellite_to_insitu,
            primary_b=primary_b,
            secondary_b=secondary_b,
            parameter_b=parameter_b,
            tt_b=tt_b,
            rt_b=rt_b,
            platforms_b=platforms_b,
            bounding_wkt_b=bounding_wkt_b,
            depth_min_b=depth_min_b,
            depth_max_b=depth_max_b,
            tile_service_factory=tile_service_factory
        ),
        preservesPartitioning=True
    ).filter(within_time_tolerance)

    if match_once:
        # Only the 'nearest' point for each primary should be returned. Add an extra map/reduce which calculates
        # the distance and finds the minimum

        # Method used for calculating the distance between 2 DomsPoints
        from pyproj import Geod

        def dist(primary, matchup):
            # The Geod is built inside the function so it is created where the function runs rather than being
            # captured in the closure that Spark serializes.
            wgs84_geod = Geod(ellps='WGS84')
            lat1, lon1 = (primary.latitude, primary.longitude)
            lat2, lon2 = (matchup.latitude, matchup.longitude)
            az12, az21, distance = wgs84_geod.inv(lon1, lat1, lon2, lat2)
            return distance

        def with_distance(primary_matchup):
            # (primary, matchup) -> (primary, (matchup, distance)) so the primary becomes the reduce key.
            primary, matchup = primary_matchup[0], primary_matchup[1]
            return (primary, (matchup, dist(primary, matchup)))

        def nearer(match_1, match_2):
            # Each argument is (matchup, distance); on equal distances the second argument is kept.
            return match_1 if match_1[1] < match_2[1] else match_2

        rdd_filtered = rdd_filtered \
            .map(with_distance) \
            .reduceByKey(nearer) \
            .mapValues(lambda x: [x[0]])  # Drop the distance, wrap the single matchup in a list
    else:
        rdd_filtered = rdd_filtered \
            .combineByKey(lambda value: [value],  # Create 1 element list
                          lambda value_list, value: value_list + [value],  # Add 1 element to list
                          lambda value_list_a, value_list_b: value_list_a + value_list_b)  # Add two lists together

    return rdd_filtered.collectAsMap()