def get_locations()

in siammot/modelling/track_head/EMM/track_core.py [0:0]


def get_locations(fmap: torch.Tensor, template_fmap: torch.Tensor,
                  sr_boxes: "list[BoxList]", shift_xy, up_scale=1):
    """Compute image-space (x, y) sample locations inside each search box.

    Every box is sampled on a regular grid whose size is the spatial size of
    ``fmap`` multiplied by ``up_scale``. The grid includes both box edges.
    ``border * up_scale`` cells are then dropped from each side of the grid,
    where ``border = template_size // 2``. (Presumably this compensates for
    the output shrinkage of a "valid" correlation with the template; confirm
    against the caller.) Finally, ``shift_xy`` is subtracted from the
    coordinates.

    Args:
        fmap: tensor whose last two dims give the base grid size ``(h, w)``.
        template_fmap: tensor whose last two dims must be equal (square);
            its size determines the border that is cropped.
        sr_boxes: sequence of objects exposing a ``.bbox`` tensor of shape
            ``(K, 4)`` laid out as ``(x1, y1, x2, y2)``. All boxes from all
            entries are concatenated in order.
        shift_xy: pair ``(dx, dy)`` subtracted from the x and y coordinates.
        up_scale: multiplier applied to the grid resolution.

    Returns:
        Tensor of shape ``(num_boxes, ny * nx, 2)`` holding ``(x, y)``
        pairs, where ``ny``/``nx`` are the grid sizes after cropping. The
        flattened axis is row-major over ``(y, x)``, so x varies fastest.

    Raises:
        AssertionError: if ``template_fmap`` is not square.
    """
    h, w = fmap.size()[-2:]
    h, w = h * up_scale, w * up_scale
    concat_boxes = torch.cat([b.bbox for b in sr_boxes], dim=0)
    box_w = concat_boxes[:, 2] - concat_boxes[:, 0]
    box_h = concat_boxes[:, 3] - concat_boxes[:, 1]
    # Both edges are sampled, so there are (n - 1) intervals per axis.
    # NOTE(review): h or w == 1 would divide by zero here; assumed not to occur.
    stride_h = box_h / (h - 1)
    stride_w = box_w / (w - 1)

    device = concat_boxes.device
    steps_x = torch.arange(0, w, dtype=torch.float32, device=device)
    steps_y = torch.arange(0, h, dtype=torch.float32, device=device)

    # Per-box 1-D sample coordinates: (N, w) for x and (N, h) for y.
    xs = concat_boxes[:, 0:1] + steps_x[None, :] * stride_w[:, None]
    ys = concat_boxes[:, 1:2] + steps_y[None, :] * stride_h[:, None]

    h0, w0 = template_fmap.shape[-2:]
    assert h0 == w0, "template feature map must be square"
    border = h0 // 2
    crop = int(border * up_scale)
    # Explicit end indices (rather than `-crop`) keep the full grid when
    # crop == 0; `s:-s` would slice to an empty range in that case.
    xs = xs[:, crop:xs.shape[1] - crop]
    ys = ys[:, crop:ys.shape[1] - crop]

    # Build all per-box (y, x) grids at once by broadcasting. Flattening the
    # (ny, nx) grid row-major puts x on the fastest-varying position.
    num_boxes, nx = xs.shape
    ny = ys.shape[1]
    grid_x = xs[:, None, :].expand(num_boxes, ny, nx)
    grid_y = ys[:, :, None].expand(num_boxes, ny, nx)
    locations = torch.stack((grid_x, grid_y), dim=-1).reshape(
        num_boxes, ny * nx, 2)

    # shift the coordinates w.r.t the original image space (before padding)
    locations[:, :, 0] -= shift_xy[0]
    locations[:, :, 1] -= shift_xy[1]

    return locations