def initialize_coords()

in models/trunks/spconv/models/conditional_random_fields.py [0:0]


  def initialize_coords(self, model, in_coords, in_color):
    """Build the lattice coordinates the CRF operates on.

    Each point is quantized into a lattice cell. Its first three coordinate
    columns are divided by ``self.spatial_sigma`` and its color by
    ``self.chromatic_sigma``; both are floored to int. Any remaining
    coordinate columns ("(time and) batch") are appended unchanged.

    When ``model.OUT_PIXEL_DIST`` is not 1 in every dimension, the model's
    output points differ from the input points. In that case two float32
    sparse COO matrices of ones are built and stored on ``self``:

    * ``self.in_mapping``  -- shape ``(self.n_rows, num_output_points)``;
      column j holds a 1 in the row that ``self.get_index_map`` returns
      for output point j.
    * ``self.out_mapping`` -- shape ``(num_input_points, self.n_rows)``;
      row i holds a 1 in the column that ``self.get_index_map`` returns
      for input point i.

    Args:
      model: network whose ``OUT_PIXEL_DIST``, ``D``, ``get_coords`` and
        ``permute_feature`` are used to obtain the output-resolution points.
      in_coords: input point coordinates. Assumed to be 2-D (points x
        columns) with at least 3 spatial columns -- TODO confirm.
      in_color: per-point color features aligned row-wise with
        ``in_coords``.

    Returns:
      The lattice coordinates. With mapping, these are the output-point rows
      followed by the input-point rows, concatenated along dim 0. Without
      mapping, they are the input-point rows only.

    Side effects:
      Always sets ``self.requires_mapping``. When it is True, also sets
      ``self.in_mapping`` and ``self.out_mapping``; these are moved to CUDA
      when ``self.config.is_cuda`` is truthy.
    """
    pixel_dist = convert_to_int_tensor(model.OUT_PIXEL_DIST, model.D)
    if torch.prod(pixel_dist) == 1:
      # The output resolution equals the input resolution, so the CRF runs
      # directly on the input points and needs no to/from mappings.
      self.requires_mapping = False
      return self._crf_lattice_coords(in_coords, in_color)

    self.requires_mapping = True

    out_coords = model.get_coords(model.OUT_PIXEL_DIST)
    # Colors are resampled onto the output points. They are truncated to int
    # before quantization, unlike the input colors; this matches the
    # original behavior.
    out_color = model.permute_feature(in_color, model.OUT_PIXEL_DIST).int()

    # Tri/Bi-lateral grid coordinates at both resolutions.
    out_tri_coords = self._crf_lattice_coords(out_coords, out_color)
    orig_tri_coords = self._crf_lattice_coords(in_coords, in_color)
    crf_tri_coords = torch.cat((out_tri_coords, orig_tri_coords), dim=0)

    # NOTE(review): the lattice is not (re)initialized from crf_tri_coords
    # here, because the super(MeanField, self).initialize_coords_with_duplicates
    # call was disabled. As a result, get_index_map / n_rows below rely on
    # state set elsewhere -- confirm this ordering is intended.

    # Scatter output-resolution points into lattice rows.
    in_cols = self.get_index_map(out_tri_coords, 1).long()
    num_out = in_cols.size(0)
    self.in_mapping = self._sparse_ones(
        in_cols, torch.arange(num_out, dtype=torch.long),
        (self.n_rows, num_out))

    # Gather lattice rows back onto the original input points.
    out_cols = self.get_index_map(orig_tri_coords, 1).long()
    num_in = out_cols.size(0)
    self.out_mapping = self._sparse_ones(
        torch.arange(num_in, dtype=torch.long), out_cols,
        (num_in, self.n_rows))

    if self.config.is_cuda:
      self.in_mapping = self.in_mapping.cuda()
      self.out_mapping = self.out_mapping.cuda()

    return crf_tri_coords

  def _crf_lattice_coords(self, coords, color):
    """Quantize points into CRF lattice cells.

    Floors ``coords[:, :3] / self.spatial_sigma`` and
    ``color / self.chromatic_sigma`` to int, then appends ``coords[:, 3:]``
    (the "(time and) batch" columns) unchanged, concatenating along dim 1.
    """
    spatial = torch.floor(coords[:, :3].float() / self.spatial_sigma).int()
    chromatic = torch.floor(color.float() / self.chromatic_sigma).int()
    return torch.cat([spatial, chromatic, coords[:, 3:]], dim=1)

  @staticmethod
  def _sparse_ones(row_idx, col_idx, size):
    """Return a float32 sparse COO matrix of ``size`` with 1.0 at each (row_idx[k], col_idx[k]).

    The result is left uncoalesced, exactly as the indices are given.
    """
    return torch.sparse_coo_tensor(
        torch.stack((row_idx, col_idx)),
        torch.ones(row_idx.size(0), dtype=torch.float32),
        torch.Size(size))