void octree_dropout()

in octree/octree/transform_octree.cpp [286:431]


// Randomly removes sub-trees from a serialized octree.
//
// Every node at level `depth_dropout` draws a Bernoulli flag with success
// probability `threshold`. A flagged node itself stays in the octree, but all
// of its descendants (levels depth_dropout+1 .. depth) are removed and the
// node becomes a leaf (children index -1). Levels 0 .. depth_dropout keep
// their node count; deeper levels are compacted to the surviving nodes.
//
// Parameters:
//   octree_output  Receives the serialized result; it is resized to the size
//                  reported by the recomputed OctreeInfo.
//   octree_input   Serialized input octree, handed to OctreeParser::set_cpu.
//   depth_dropout  Level at which the drop flags are drawn.
//                  NOTE(review): no range check is performed here; the code
//                  indexes drop[depth_dropout], so the value is assumed to
//                  lie in [0, depth] -- confirm against the callers.
//   threshold      Probability that a node at `depth_dropout` is flagged.
//
// NOTE(review): the random engine is seeded with time(nullptr), so two calls
// made within the same second draw the same flag sequence.
void octree_dropout(vector<char>& octree_output, const string& octree_input,
    const int depth_dropout, const float threshold) {
  // ---- 1. generate the drop flags ----------------------------------------
  OctreeParser parser_in;
  parser_in.set_cpu(octree_input.c_str());
  const int depth = parser_in.info().depth();
  // drop[d][i] == 1 means node i of level d is flagged. Levels above
  // depth_dropout stay empty because they are never consulted.
  vector<vector<uintk>> drop(depth + 1);

  // random flags for the level depth_dropout
  const int nnum_dropout = parser_in.info().node_num(depth_dropout);
  drop[depth_dropout].resize(nnum_dropout, 0);
  std::default_random_engine generator(static_cast<unsigned>(time(nullptr)));
  std::bernoulli_distribution distribution(threshold);
  for (int i = 0; i < nnum_dropout; ++i) {
    drop[depth_dropout][i] = static_cast<unsigned>(distribution(generator));
  }

  // propagate the flags level by level: children inherit the parent's flag
  for (int d = depth_dropout + 1; d <= depth; ++d) {
    const int nnum_d = parser_in.info().node_num(d);
    const int nnum_dp = parser_in.info().node_num(d - 1);
    const int* children_dp = parser_in.children_cpu(d - 1);
    drop[d].resize(nnum_d);
    for (int i = 0; i < nnum_dp; ++i) {
      const int t = children_dp[i];
      if (t < 0) continue;  // a negative children index means "no children"
      // the 8 children of the t-th non-empty parent occupy [8t, 8t + 8)
      for (int j = 0; j < 8; ++j) {
        drop[d][t * 8 + j] = drop[d - 1][i];
      }
    }
  }

  // ---- 2. initialize the output octree -----------------------------------
  OctreeInfo info_output = parser_in.info();
  vector<int> node_num(depth + 1, 0);
  for (int d = 0; d <= depth; ++d) {
    if (d <= depth_dropout) {
      // levels up to depth_dropout keep all of their nodes
      node_num[d] = parser_in.info().node_num(d);
    } else {
      // deeper levels keep only the nodes that are not flagged
      int num_kept = 0;
      for (auto v : drop[d]) {
        if (v == 0) num_kept++;
      }
      node_num[d] = num_kept;
    }
  }
  info_output.set_nnum(node_num.data());
  info_output.set_nnum_cum();
  info_output.set_ptr_dis();
  octree_output.resize(info_output.sizeof_octree());
  OctreeParser parser_out;
  parser_out.set_cpu(octree_output.data(), &info_output);

  // ---- 3. levels 0 .. depth_dropout: plain copy --------------------------
  // `num` is the cumulative node count passed to copy_n for these levels.
  const int num = parser_in.info().node_num_cum(depth_dropout + 1);
  const int channel_key = parser_in.info().channel(OctreeInfo::kKey);
  // CHECK_EQ(channel_key, 1) << "Currently the channel must be 1";
  std::copy_n(parser_in.key_cpu(0), num * channel_key,
      parser_out.mutable_key_cpu(0));
  // The children of level depth_dropout copied here are provisional; they
  // are rewritten in step 5.
  std::copy_n(parser_in.children_cpu(0), num,
      parser_out.mutable_children_cpu(0));

  // A property is copied for these levels only when its location is -1.
  const int channel_feature = parser_in.info().channel(OctreeInfo::kFeature);
  const int location_feature =
      parser_in.info().locations(OctreeInfo::kFeature);
  if (location_feature == -1) {
    std::copy_n(parser_in.feature_cpu(0), num * channel_feature,
        parser_out.mutable_feature_cpu(0));
  }
  const int channel_label = parser_in.info().channel(OctreeInfo::kLabel);
  const int location_label = parser_in.info().locations(OctreeInfo::kLabel);
  if (location_label == -1) {
    std::copy_n(parser_in.label_cpu(0), num * channel_label,
        parser_out.mutable_label_cpu(0));
  }
  const int channel_split = parser_in.info().channel(OctreeInfo::kSplit);
  const int location_split = parser_in.info().locations(OctreeInfo::kSplit);
  if (location_split == -1) {
    std::copy_n(parser_in.split_cpu(0), num * channel_split,
        parser_out.mutable_split_cpu(0));
  }

  // ---- 4. levels depth_dropout+1 .. depth: compact the survivors ---------
  vector<int> node_num_nempty(depth + 1, 0);
  for (int d = depth_dropout + 1; d <= depth; ++d) {
    const int nnum_d = parser_in.info().node_num(d);
    const vector<uintk>& drop_d = drop[d];

    // keys and children: surviving non-empty nodes are renumbered 0, 1, 2...
    int id = 0;
    const int* child_src = parser_in.children_cpu(d);
    const uintk* key_src = parser_in.key_cpu(d);
    int* child_des = parser_out.mutable_children_cpu(d);
    uintk* key_des = parser_out.mutable_key_cpu(d);
    for (int i = 0, j = 0; i < nnum_d; ++i) {
      if (drop_d[i] != 0) continue;
      key_des[j] = key_src[i];
      child_des[j] = child_src[i] < 0 ? child_src[i] : id++;
      ++j;
    }
    node_num_nempty[d] = id;

    // features: one row of node values per channel, with the row length
    // changing from nnum_d (input) to nnum_des (output)
    if (location_feature == -1 || d == depth) {
      const int nnum_des = parser_out.info().node_num(d);
      const float* feature_src = parser_in.feature_cpu(d);
      float* feature_des = parser_out.mutable_feature_cpu(d);
      for (int i = 0, j = 0; i < nnum_d; ++i) {
        if (drop_d[i] != 0) continue;
        for (int c = 0; c < channel_feature; ++c) {
          feature_des[c * nnum_des + j] = feature_src[c * nnum_d + i];
        }
        ++j;
      }
    }

    // labels: one value per node is copied
    // NOTE(review): channel_label is only tested against 0 here; a label
    // with more than one channel would not be fully copied -- confirm that
    // the label channel is always 1.
    if ((location_label == -1 || d == depth) && channel_label != 0) {
      const float* label_src = parser_in.label_cpu(d);
      float* label_des = parser_out.mutable_label_cpu(d);
      for (int i = 0, j = 0; i < nnum_d; ++i) {
        if (drop_d[i] != 0) continue;
        label_des[j] = label_src[i];
        ++j;
      }
    }

    // split flags: one value per node is copied (same caveat as labels)
    if ((location_split == -1 || d == depth) && channel_split != 0) {
      const float* split_src = parser_in.split_cpu(d);
      float* split_des = parser_out.mutable_split_cpu(d);
      for (int i = 0, j = 0; i < nnum_d; ++i) {
        if (drop_d[i] != 0) continue;
        split_des[j] = split_src[i];
        ++j;
      }
    }
  }

  // ---- 5. fix the children and non-empty counts at depth_dropout ---------
  int id = 0;
  const int* child_src = parser_in.children_cpu(depth_dropout);
  int* child_des = parser_out.mutable_children_cpu(depth_dropout);
  for (int i = 0; i < node_num[depth_dropout]; ++i) {
    if (child_src[i] < 0) {
      // already a leaf in the input: keep its marker unchanged
      child_des[i] = child_src[i];
    } else if (drop[depth_dropout][i] == 1) {
      // All descendants of a flagged node were removed above, so the node
      // must become a leaf. Keeping its old non-negative index would make it
      // refer to children that belong to another node or do not exist.
      child_des[i] = -1;
    } else {
      // surviving parent: renumber consecutively to match the compacted
      // child level
      child_des[i] = id++;
    }
  }
  for (int d = 0; d < depth_dropout; ++d) {
    node_num_nempty[d] = parser_in.info().node_num_nempty(d);
  }
  // !!! important: only the parents that kept their children are non-empty
  node_num_nempty[depth_dropout] = id;
  parser_out.mutable_info().set_nempty(node_num_nempty.data());
}