data_preprocessing/sample_quadruplets/sample_for_counties.py [246:315]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def _sample_neighbor(img, a_lat, a_lon, neighborhood_radius, tile_radius, timestep, size_even):
    """Sample an unmasked tile whose centre lies near the anchor in the same image.

    On each axis the centre is drawn uniformly from the window
    ``[a - neighborhood_radius, a + neighborhood_radius]`` clipped to
    ``[tile_radius, dim - tile_radius)``, excluding the anchor position itself.
    Draws are repeated until the resulting tile has no masked entries; this
    never terminates if every candidate tile is masked.

    Args:
        img: masked array indexed as ``img[:, timestep, lat, lon]``; its shape
            must unpack into exactly four values.
        a_lat, a_lon: anchor centre coordinates.
        neighborhood_radius: half-width of the sampling window (inclusive).
            ``None`` delegates to ``_sample_distant_diff`` (any location).
        tile_radius: margin keeping sampled centres away from the image border.
        timestep: index into the second axis of ``img``.
        size_even: forwarded unchanged to ``_get_lat_lon_range``.

    Returns:
        ``(tile, n_lat, n_lon)``: the unmasked tile and its centre.

    Raises:
        ValueError: if the clipped window is empty on an axis (raised by
            ``np.random.randint``) or contains only the anchor position.
    """
    if neighborhood_radius is None:
        return _sample_distant_diff(img, tile_radius, timestep, size_even)

    _, _, img_h, img_w = img.shape
    # np.random.randint excludes ``high``; the ``+ 1`` makes the window
    # symmetric, i.e. a + neighborhood_radius is reachable just like
    # a - neighborhood_radius.  The image-border bound stays exclusive.
    lat_lo = max(a_lat - neighborhood_radius, tile_radius)
    lat_hi = min(a_lat + neighborhood_radius + 1, img_h - tile_radius)
    lon_lo = max(a_lon - neighborhood_radius, tile_radius)
    lon_hi = min(a_lon + neighborhood_radius + 1, img_w - tile_radius)
    # If the only candidate is the anchor itself, the rejection loop below
    # could never finish; fail loudly instead of hanging.
    if (lat_lo, lat_hi, lon_lo, lon_hi) == (a_lat, a_lat + 1, a_lon, a_lon + 1):
        raise ValueError('neighborhood around ({}, {}) contains no position other than the anchor'
                         .format(a_lat, a_lon))

    while True:
        n_lat, n_lon = a_lat, a_lon
        while n_lat == a_lat and n_lon == a_lon:
            n_lat = np.random.randint(lat_lo, lat_hi)
            n_lon = np.random.randint(lon_lo, lon_hi)
        # Bounds from _get_lat_lon_range are treated as inclusive (hence + 1).
        lat0, lat1, lon0, lon1 = _get_lat_lon_range(n_lat, n_lon, tile_radius, size_even)
        tile = img[:, timestep, lat0:lat1 + 1, lon0:lon1 + 1]
        if ma.count_masked(tile) == 0:
            return tile, n_lat, n_lon


def _sample_distant_same(img, a_lat, a_lon, neighborhood_radius, distant_radius, tile_radius, timestep, size_even):
    """Sample an unmasked tile whose centre is away from the anchor, in the same image.

    Each axis is constrained independently: along an axis the centre's
    distance from the anchor must be strictly greater than
    ``neighborhood_radius`` and, when ``distant_radius`` is given, strictly
    smaller than ``distant_radius``.  Centres are restricted to
    ``[tile_radius, dim - tile_radius)`` and drawn uniformly from the
    admissible coordinates.  Draws are repeated until the tile has no masked
    entries; this never terminates if every admissible tile is masked.

    Args:
        img: masked array indexed as ``img[:, timestep, lat, lon]``; its shape
            must unpack into exactly four values.
        a_lat, a_lon: anchor centre coordinates.
        neighborhood_radius: per-axis exclusion radius around the anchor
            (inclusive).  ``None`` delegates to ``_sample_distant_diff``.
        distant_radius: exclusive per-axis upper bound on the distance from
            the anchor, or ``None`` for no upper bound.
        tile_radius: margin keeping sampled centres away from the image border.
        timestep: index into the second axis of ``img``.
        size_even: forwarded unchanged to ``_get_lat_lon_range``.

    Returns:
        ``(tile, d_lat, d_lon)``: the unmasked tile and its centre.

    Raises:
        ValueError: if no coordinate on some axis satisfies the constraints
            (the previous rejection-sampling loop spun forever in that case).
    """
    if neighborhood_radius is None:
        return _sample_distant_diff(img, tile_radius, timestep, size_even)

    _, _, img_h, img_w = img.shape
    # The admissible coordinates do not change between retries: compute once.
    lat_choices = _distant_axis_candidates(a_lat, neighborhood_radius, distant_radius, tile_radius, img_h)
    lon_choices = _distant_axis_candidates(a_lon, neighborhood_radius, distant_radius, tile_radius, img_w)
    if lat_choices.size == 0 or lon_choices.size == 0:
        raise ValueError('no admissible distant centre for anchor ({}, {}) with neighborhood_radius={}, '
                         'distant_radius={}'.format(a_lat, a_lon, neighborhood_radius, distant_radius))

    while True:
        d_lat = int(np.random.choice(lat_choices))
        d_lon = int(np.random.choice(lon_choices))
        # Bounds from _get_lat_lon_range are treated as inclusive (hence + 1).
        lat0, lat1, lon0, lon1 = _get_lat_lon_range(d_lat, d_lon, tile_radius, size_even)
        tile = img[:, timestep, lat0:lat1 + 1, lon0:lon1 + 1]
        if ma.count_masked(tile) == 0:
            return tile, d_lat, d_lon


def _distant_axis_candidates(center, neighborhood_radius, distant_radius, tile_radius, size):
    """Return the 1-D array of admissible distant coordinates along one axis."""
    coords = np.arange(tile_radius, size - tile_radius)
    dist = np.abs(coords - center)
    keep = dist > neighborhood_radius
    if distant_radius is not None:
        keep &= dist < distant_radius
    return coords[keep]


def _sample_distant_diff(img, tile_radius, timestep, size_even):
    """Sample an unmasked tile centred anywhere in the image.

    The centre is drawn uniformly from ``[tile_radius, dim - tile_radius)`` on
    each axis (latitude first, then longitude) and redrawn until the tile has
    no masked entries.  This never terminates if every candidate is masked.

    Args:
        img: masked array indexed as ``img[:, timestep, lat, lon]``; its shape
            must unpack into exactly four values.
        tile_radius: margin keeping sampled centres away from the image border.
        timestep: index into the second axis of ``img``.
        size_even: forwarded unchanged to ``_get_lat_lon_range``.

    Returns:
        ``(tile, lat, lon)``: the unmasked tile and its centre.
    """
    _, _, height, width = img.shape
    found = False
    while not found:
        center_row = np.random.randint(tile_radius, height - tile_radius)
        center_col = np.random.randint(tile_radius, width - tile_radius)
        row_start, row_end, col_start, col_end = _get_lat_lon_range(
            center_row, center_col, tile_radius, size_even)
        # The range ends are inclusive, so extend the slice stop by one.
        tile = img[:, timestep, row_start:row_end + 1, col_start:col_end + 1]
        found = ma.count_masked(tile) == 0
    return tile, center_row, center_col


def plot_sampled_centers(lats, lons, img_shape, out_dir, name):
    """Save a scatter plot of sampled tile centres as ``<out_dir>/<name>.jpg``.

    Args:
        lats: row coordinates of the centres (plotted on the y axis).
        lons: column coordinates of the centres (plotted on the x axis).
        img_shape: 4-tuple; its last two entries (height, width) set the
            plot extent ``[0, width] x [0, height]``.
        out_dir: directory to write into; it is not created here.
        name: file stem of the saved image.
    """
    _, _, h, w = img_shape
    # Draw on a dedicated figure instead of pyplot's implicit "current" one and
    # always close it, so a failing savefig cannot leak the figure into (and
    # pollute) later plots.
    fig, ax = plt.subplots()
    try:
        ax.scatter(lons, lats, s=5)
        ax.axis([0, w, 0, h])
        fig.savefig('{}/{}.jpg'.format(out_dir, name))
    finally:
        plt.close(fig)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



data_preprocessing/sample_quadruplets/sample_for_pretrained.py [177:246]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def _sample_neighbor(img, a_lat, a_lon, neighborhood_radius, tile_radius, timestep, size_even):
    """Sample an unmasked tile whose centre lies near the anchor in the same image.

    On each axis the centre is drawn uniformly from the window
    ``[a - neighborhood_radius, a + neighborhood_radius]`` clipped to
    ``[tile_radius, dim - tile_radius)``, excluding the anchor position itself.
    Draws are repeated until the resulting tile has no masked entries; this
    never terminates if every candidate tile is masked.

    Args:
        img: masked array indexed as ``img[:, timestep, lat, lon]``; its shape
            must unpack into exactly four values.
        a_lat, a_lon: anchor centre coordinates.
        neighborhood_radius: half-width of the sampling window (inclusive).
            ``None`` delegates to ``_sample_distant_diff`` (any location).
        tile_radius: margin keeping sampled centres away from the image border.
        timestep: index into the second axis of ``img``.
        size_even: forwarded unchanged to ``_get_lat_lon_range``.

    Returns:
        ``(tile, n_lat, n_lon)``: the unmasked tile and its centre.

    Raises:
        ValueError: if the clipped window is empty on an axis (raised by
            ``np.random.randint``) or contains only the anchor position.
    """
    if neighborhood_radius is None:
        return _sample_distant_diff(img, tile_radius, timestep, size_even)

    _, _, img_h, img_w = img.shape
    # np.random.randint excludes ``high``; the ``+ 1`` makes the window
    # symmetric, i.e. a + neighborhood_radius is reachable just like
    # a - neighborhood_radius.  The image-border bound stays exclusive.
    lat_lo = max(a_lat - neighborhood_radius, tile_radius)
    lat_hi = min(a_lat + neighborhood_radius + 1, img_h - tile_radius)
    lon_lo = max(a_lon - neighborhood_radius, tile_radius)
    lon_hi = min(a_lon + neighborhood_radius + 1, img_w - tile_radius)
    # If the only candidate is the anchor itself, the rejection loop below
    # could never finish; fail loudly instead of hanging.
    if (lat_lo, lat_hi, lon_lo, lon_hi) == (a_lat, a_lat + 1, a_lon, a_lon + 1):
        raise ValueError('neighborhood around ({}, {}) contains no position other than the anchor'
                         .format(a_lat, a_lon))

    while True:
        n_lat, n_lon = a_lat, a_lon
        while n_lat == a_lat and n_lon == a_lon:
            n_lat = np.random.randint(lat_lo, lat_hi)
            n_lon = np.random.randint(lon_lo, lon_hi)
        # Bounds from _get_lat_lon_range are treated as inclusive (hence + 1).
        lat0, lat1, lon0, lon1 = _get_lat_lon_range(n_lat, n_lon, tile_radius, size_even)
        tile = img[:, timestep, lat0:lat1 + 1, lon0:lon1 + 1]
        if ma.count_masked(tile) == 0:
            return tile, n_lat, n_lon


def _sample_distant_same(img, a_lat, a_lon, neighborhood_radius, distant_radius, tile_radius, timestep, size_even):
    """Sample an unmasked tile whose centre is away from the anchor, in the same image.

    Each axis is constrained independently: along an axis the centre's
    distance from the anchor must be strictly greater than
    ``neighborhood_radius`` and, when ``distant_radius`` is given, strictly
    smaller than ``distant_radius``.  Centres are restricted to
    ``[tile_radius, dim - tile_radius)`` and drawn uniformly from the
    admissible coordinates.  Draws are repeated until the tile has no masked
    entries; this never terminates if every admissible tile is masked.

    Args:
        img: masked array indexed as ``img[:, timestep, lat, lon]``; its shape
            must unpack into exactly four values.
        a_lat, a_lon: anchor centre coordinates.
        neighborhood_radius: per-axis exclusion radius around the anchor
            (inclusive).  ``None`` delegates to ``_sample_distant_diff``.
        distant_radius: exclusive per-axis upper bound on the distance from
            the anchor, or ``None`` for no upper bound.
        tile_radius: margin keeping sampled centres away from the image border.
        timestep: index into the second axis of ``img``.
        size_even: forwarded unchanged to ``_get_lat_lon_range``.

    Returns:
        ``(tile, d_lat, d_lon)``: the unmasked tile and its centre.

    Raises:
        ValueError: if no coordinate on some axis satisfies the constraints
            (the previous rejection-sampling loop spun forever in that case).
    """
    if neighborhood_radius is None:
        return _sample_distant_diff(img, tile_radius, timestep, size_even)

    _, _, img_h, img_w = img.shape
    # The admissible coordinates do not change between retries: compute once.
    lat_choices = _distant_axis_candidates(a_lat, neighborhood_radius, distant_radius, tile_radius, img_h)
    lon_choices = _distant_axis_candidates(a_lon, neighborhood_radius, distant_radius, tile_radius, img_w)
    if lat_choices.size == 0 or lon_choices.size == 0:
        raise ValueError('no admissible distant centre for anchor ({}, {}) with neighborhood_radius={}, '
                         'distant_radius={}'.format(a_lat, a_lon, neighborhood_radius, distant_radius))

    while True:
        d_lat = int(np.random.choice(lat_choices))
        d_lon = int(np.random.choice(lon_choices))
        # Bounds from _get_lat_lon_range are treated as inclusive (hence + 1).
        lat0, lat1, lon0, lon1 = _get_lat_lon_range(d_lat, d_lon, tile_radius, size_even)
        tile = img[:, timestep, lat0:lat1 + 1, lon0:lon1 + 1]
        if ma.count_masked(tile) == 0:
            return tile, d_lat, d_lon


def _distant_axis_candidates(center, neighborhood_radius, distant_radius, tile_radius, size):
    """Return the 1-D array of admissible distant coordinates along one axis."""
    coords = np.arange(tile_radius, size - tile_radius)
    dist = np.abs(coords - center)
    keep = dist > neighborhood_radius
    if distant_radius is not None:
        keep &= dist < distant_radius
    return coords[keep]


def _sample_distant_diff(img, tile_radius, timestep, size_even):
    """Sample an unmasked tile centred anywhere in the image.

    The centre is drawn uniformly from ``[tile_radius, dim - tile_radius)`` on
    each axis (latitude first, then longitude) and redrawn until the tile has
    no masked entries.  This never terminates if every candidate is masked.

    Args:
        img: masked array indexed as ``img[:, timestep, lat, lon]``; its shape
            must unpack into exactly four values.
        tile_radius: margin keeping sampled centres away from the image border.
        timestep: index into the second axis of ``img``.
        size_even: forwarded unchanged to ``_get_lat_lon_range``.

    Returns:
        ``(tile, lat, lon)``: the unmasked tile and its centre.
    """
    _, _, height, width = img.shape
    found = False
    while not found:
        center_row = np.random.randint(tile_radius, height - tile_radius)
        center_col = np.random.randint(tile_radius, width - tile_radius)
        row_start, row_end, col_start, col_end = _get_lat_lon_range(
            center_row, center_col, tile_radius, size_even)
        # The range ends are inclusive, so extend the slice stop by one.
        tile = img[:, timestep, row_start:row_end + 1, col_start:col_end + 1]
        found = ma.count_masked(tile) == 0
    return tile, center_row, center_col


def plot_sampled_centers(lats, lons, img_shape, out_dir, name):
    """Save a scatter plot of sampled tile centres as ``<out_dir>/<name>.jpg``.

    Args:
        lats: row coordinates of the centres (plotted on the y axis).
        lons: column coordinates of the centres (plotted on the x axis).
        img_shape: 4-tuple; its last two entries (height, width) set the
            plot extent ``[0, width] x [0, height]``.
        out_dir: directory to write into; it is not created here.
        name: file stem of the saved image.
    """
    _, _, h, w = img_shape
    # Draw on a dedicated figure instead of pyplot's implicit "current" one and
    # always close it, so a failing savefig cannot leak the figure into (and
    # pollute) later plots.
    fig, ax = plt.subplots()
    try:
        ax.scatter(lons, lats, s=5)
        ax.axis([0, w, 0, h])
        fig.savefig('{}/{}.jpg'.format(out_dir, name))
    finally:
        plt.close(fig)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



