PSUtils::Result PSUtils::DoSampling()

in lingvo/tasks/car/ops/ps_utils.cc [250:454]


PSUtils::Result PSUtils::DoSampling(const Tensor& points,
                                    const Tensor& points_padding,
                                    const int32 num_seeded_points) const {
  // Samples centers and, for each center, nearby neighbor points.
  //
  // For every batch element, picks opts_.num_centers centers. For each center
  // it samples up to opts_.num_neighbors points whose squared 3D distance to
  // the center is <= Square(opts_.max_dist).
  //
  // Args:
  //   points: float tensor [batch_size, num_points, 3] holding x, y, z.
  //   points_padding: float matrix [batch_size, num_points]; a point is
  //     treated as real iff its padding value is 0.0.
  //   num_seeded_points: the first num_seeded_points points are used, in
  //     order, as the first centers. They are never offered to the selector
  //     and never returned as neighbors.
  //     NOTE(review): assumes num_seeded_points <= num_points; not checked
  //     here.
  //
  // Returns a Result with:
  //   center          int32 [batch_size, num_centers]: point index per center.
  //   center_padding  float [batch_size, num_centers]: 1.0 where no center
  //                   could be picked (the selector returned a negative
  //                   index), else 0.0.
  //   indices         int32 [batch_size, num_centers, num_neighbors]: point
  //                   indices of the sampled neighbors (0 in unused slots).
  //   indices_padding float [batch_size, num_centers, num_neighbors]: 0.0 for
  //                   filled slots, 1.0 otherwise.

  // Points must be of rank 3, and padding must be a matrix.
  DCHECK_EQ(points.dims(), 3);
  DCHECK_EQ(points_padding.dims(), 2);
  // 3D points.
  DCHECK_EQ(points.dim_size(2), 3);
  DCHECK_EQ(points.dim_size(0), points_padding.dim_size(0));
  DCHECK_EQ(points.dim_size(1), points_padding.dim_size(1));

  auto points_t = points.tensor<float, 3>();
  auto points_padding_t = points_padding.matrix<float>();
  const int64_t batch_size = points.dim_size(0);
  const int64_t num_points = points.dim_size(1);

  Result result;
  result.center = Tensor(DT_INT32, {batch_size, opts_.num_centers});
  result.center_padding = Tensor(DT_FLOAT, {batch_size, opts_.num_centers});
  result.indices =
      Tensor(DT_INT32, {batch_size, opts_.num_centers, opts_.num_neighbors});
  result.indices_padding =
      Tensor(DT_FLOAT, {batch_size, opts_.num_centers, opts_.num_neighbors});

  auto center_t = result.center.matrix<int32>();
  center_t.setConstant(0);
  // Every entry is written in the center loop below, so no initialization.
  auto center_padding_t = result.center_padding.matrix<float>();
  auto indices_t = result.indices.tensor<int32, 3>();
  indices_t.setConstant(0);
  auto padding_t = result.indices_padding.tensor<float, 3>();
  padding_t.setConstant(1.0);

  // Max distance squared as the threshold.
  const float threshold = Square(opts_.max_dist);

  // The idea behind the hash lookup is to only do neighbor / distance checks
  // for plausibly close neighbors, rather than looking at all points for each
  // center.  We do this by gridifying the points, and for each center only
  // looking at points in nearby grid cells.  This is a cheap version of a more
  // sophisticated algorithm like using a KDTree or RangeTree.
  const bool use_hash_lookup =
      (opts_.neighbor_search_algorithm == PSUtils::Options::N_HASH);

  for (int cur_batch = 0; cur_batch < batch_size; ++cur_batch) {
    std::vector<bool> candidates(num_points);

    float xmin = std::numeric_limits<float>::max();
    float ymin = std::numeric_limits<float>::max();
    float zmin = std::numeric_limits<float>::max();
    float xmax = std::numeric_limits<float>::lowest();
    float ymax = std::numeric_limits<float>::lowest();
    float zmax = std::numeric_limits<float>::lowest();

    for (int i = 0; i < num_points; ++i) {
      // The first num_seeded_points are not candidates of the selector, because
      // they are always selected.
      candidates[i] =
          (i >= num_seeded_points && points_padding_t(cur_batch, i) == 0.0) &&
          (opts_.center_z_min <= points_t(cur_batch, i, 2)) &&
          (points_t(cur_batch, i, 2) <= opts_.center_z_max);

      // Find min / max points for computing grid buckets.
      if (use_hash_lookup && points_padding_t(cur_batch, i) == 0.0) {
        xmin = std::min(points_t(cur_batch, i, 0), xmin);
        xmax = std::max(points_t(cur_batch, i, 0), xmax);
        ymin = std::min(points_t(cur_batch, i, 1), ymin);
        ymax = std::max(points_t(cur_batch, i, 1), ymax);
        zmin = std::min(points_t(cur_batch, i, 2), zmin);
        zmax = std::max(points_t(cur_batch, i, 2), zmax);
      }
    }

    // Stores a mapping of bucket_id -> list of point indices.  The buckets are
    // the voxelized breakdown of the 3D space and points fall into these
    // voxels.  The length of the cube is the max_distance.
    std::vector<std::vector<int>> buckets_vec;

    int x_intervals = 0;
    int y_intervals = 0;
    int z_intervals = 0;

    // When every point of this batch element is padded, the extents keep
    // their initial max()/lowest() values (min > max). Deriving interval
    // counts from them would convert -inf to int (undefined behavior), so the
    // grid is skipped and every center simply gets no neighbors. Comparisons
    // involving NaN are false as well, so NaN extents also skip the grid.
    const bool have_buckets = use_hash_lookup && xmin <= xmax &&
                              ymin <= ymax && zmin <= zmax;

    if (have_buckets) {
      // Adjust boundaries to avoid edge conditions.  We use max_dist as a
      // conservative estimate.
      xmin -= opts_.max_dist;
      ymin -= opts_.max_dist;
      zmin -= opts_.max_dist;
      xmax += opts_.max_dist;
      ymax += opts_.max_dist;
      zmax += opts_.max_dist;

      const float x_cells = std::ceil((xmax - xmin) / opts_.max_dist);
      const float y_cells = std::ceil((ymax - ymin) / opts_.max_dist);
      const float z_cells = std::ceil((zmax - zmin) / opts_.max_dist);

      // The interval counts and bucket coordinates are ints. Check the total
      // bucket count in floating point first, so a huge extent (or a tiny
      // max_dist) fails loudly instead of overflowing the int product.
      CHECK_LE(static_cast<double>(x_cells) * y_cells * z_cells,
               static_cast<double>(std::numeric_limits<int>::max()))
          << "Point extents / max_dist yield too many hash buckets.";
      x_intervals = static_cast<int>(x_cells);
      y_intervals = static_cast<int>(y_cells);
      z_intervals = static_cast<int>(z_cells);

      // The number of buckets is the product of all the intervals.
      buckets_vec.resize(static_cast<size_t>(x_intervals) * y_intervals *
                         z_intervals);
      for (int i = 0; i < num_points; ++i) {
        // Compute which bucket each valid point falls into.
        //
        // A valid point is a non-padded, non-seeded point.
        if (points_padding_t(cur_batch, i) == 0.0 && i >= num_seeded_points) {
          const int bucket_x =
              FindBucket(points_t(cur_batch, i, 0), xmin, opts_.max_dist);
          const int bucket_y =
              FindBucket(points_t(cur_batch, i, 1), ymin, opts_.max_dist);
          const int bucket_z =
              FindBucket(points_t(cur_batch, i, 2), zmin, opts_.max_dist);
          if (bucket_x >= 0 && bucket_x < x_intervals && bucket_y >= 0 &&
              bucket_y < y_intervals && bucket_z >= 0 &&
              bucket_z < z_intervals) {
            // Compute the linearized bucket offset.
            const auto bucket_id = BucketId(bucket_x, bucket_y, bucket_z,
                                            y_intervals, z_intervals);
            buckets_vec[bucket_id].push_back(i);
          }
        }
      }
    }

    Selector selector(candidates, Seed());
    Sampler sampler(opts_.num_neighbors, Seed());

    // Indices examined for the current center. Without the hash lookup, every
    // non-padded, non-seeded point is examined for every center, so the list
    // is built once per batch element. With the hash lookup, the buffer is
    // cleared and refilled per center.
    std::vector<int> neighbor_idx;
    if (!use_hash_lookup) {
      neighbor_idx.reserve(num_points);
      for (int j = num_seeded_points; j < num_points; ++j) {
        if (points_padding_t(cur_batch, j) == 0.0) {
          neighbor_idx.push_back(j);
        }
      }
    }

    for (int i = 0; i < opts_.num_centers; ++i) {
      // Pick a point as i-th center: seeded points first, in order, then
      // whatever the selector returns (negative when it has nothing left).
      int k;
      if (i < num_seeded_points) {
        k = i;
      } else {
        k = selector.Get();
      }

      if (k < 0) {
        center_padding_t(cur_batch, i) = 1.0;
        continue;
      }
      center_padding_t(cur_batch, i) = 0.0;
      center_t(cur_batch, i) = k;

      const float cx = points_t(cur_batch, k, 0);
      const float cy = points_t(cur_batch, k, 1);
      const float cz = points_t(cur_batch, k, 2);

      // Goes through all candidate *non-seeded* points. If j-th point is
      // within a radius of center, adds it to the sampler.
      sampler.Reset();

      if (use_hash_lookup) {
        neighbor_idx.clear();
        if (have_buckets) {
          // For each center, compute the bucket it is in.
          const int bucket_x = FindBucket(cx, xmin, opts_.max_dist);
          const int bucket_y = FindBucket(cy, ymin, opts_.max_dist);
          const int bucket_z = FindBucket(cz, zmin, opts_.max_dist);

          // Iterate over the 3x3x3 buckets centered at
          // [bucket_x, bucket_y, bucket_z] and collect their point indices.
          // Buckets are max_dist wide, so this covers every point within
          // max_dist of the center.
          for (int bx = bucket_x - 1; bx <= bucket_x + 1; ++bx) {
            if (bx < 0 || bx >= x_intervals) continue;
            for (int by = bucket_y - 1; by <= bucket_y + 1; ++by) {
              if (by < 0 || by >= y_intervals) continue;
              for (int bz = bucket_z - 1; bz <= bucket_z + 1; ++bz) {
                if (bz < 0 || bz >= z_intervals) continue;
                const auto bucket_id =
                    BucketId(bx, by, bz, y_intervals, z_intervals);
                // Bind by reference: copying each bucket here was pure
                // overhead.
                const auto& bucket_indices = buckets_vec[bucket_id];
                neighbor_idx.insert(neighbor_idx.end(), bucket_indices.begin(),
                                    bucket_indices.end());
              }
            }
          }
        }
      }

      // Iterate over all neighbor indices.
      for (int j : neighbor_idx) {
        auto ss_xy = Square(cx - points_t(cur_batch, j, 0)) +
                     Square(cy - points_t(cur_batch, j, 1));
        auto z = points_t(cur_batch, j, 2);
        auto ss_xyz = ss_xy + Square(cz - z);
        if (ss_xyz <= threshold) {
          sampler.Add(j, ss_xyz);
        }
        // The selector is updated with the planar (xy) squared distance.
        selector.Update(j, ss_xy);
      }
      auto ids = sampler.Get();
      CHECK_LE(ids.size(), opts_.num_neighbors);
      const int num_ids = static_cast<int>(ids.size());
      for (int j = 0; j < num_ids; ++j) {
        indices_t(cur_batch, i, j) = ids[j].id;
        padding_t(cur_batch, i, j) = 0.0f;
      }
    }
  }

  return result;
}