in octree/octree/transform_octree.cpp [286:431]
// Copies, for every channel, the entries of one octree level whose drop flag
// is 0 from `src` to `des`, compacting them in their original order.
// Buffers are indexed channel-major (`c * nnum + i`), the layout octree_dropout
// uses for features; with `channel == 1` this is a plain compaction.
// NOTE(review): label/split are assumed to share the feature layout — this is
// identical to a single-channel copy whenever their channel count is 1.
template <typename FlagVec>
static void octree_dropout_copy_kept(const float* src, const int nnum_src,
                                     float* des, const int nnum_des,
                                     const int channel,
                                     const FlagVec& drop_flag) {
  for (int i = 0, j = 0; i < nnum_src; ++i) {
    if (drop_flag[i] != 0) continue;
    for (int c = 0; c < channel; ++c) {
      des[c * nnum_des + j] = src[c * nnum_src + i];
    }
    ++j;
  }
}

/// Structural dropout: randomly removes whole subtrees from an octree.
///
/// Each node at level `depth_dropout` is independently marked with
/// probability `threshold`. A marked node keeps its own entry at
/// `depth_dropout` but loses all of its descendants (it becomes a leaf).
/// Levels 0..depth_dropout are copied unchanged; deeper levels keep only the
/// nodes whose ancestor at `depth_dropout` was not marked.
///
/// @param octree_output  Receives the serialized output octree (resized here).
/// @param octree_input   Serialized input octree.
/// @param depth_dropout  Level at which subtrees are dropped. If it is outside
///                       [0, depth] the input is copied through unchanged.
/// @param threshold      Drop probability per node; clamped to [0, 1].
void octree_dropout(vector<char>& octree_output, const string& octree_input,
                    const int depth_dropout, const float threshold) {
  OctreeParser parser_in;
  parser_in.set_cpu(octree_input.c_str());
  const int depth = parser_in.info().depth();

  // Drop flags are indexed by level 0..depth; there is nothing to drop at a
  // level that does not exist, so hand back the input as-is.
  if (depth_dropout < 0 || depth_dropout > depth) {
    octree_output.assign(octree_input.begin(), octree_input.end());
    return;
  }

  // ---- 1. drop flags ------------------------------------------------------
  // drop[d][i] != 0 means node i of level d is removed. Only levels
  // >= depth_dropout are populated.
  vector<vector<uintk>> drop(depth + 1);
  const int nnum_dropout = parser_in.info().node_num(depth_dropout);
  drop[depth_dropout].resize(nnum_dropout, 0);
  // One engine per thread, seeded once: re-seeding from time(nullptr) on each
  // call replayed the same sequence for every call within the same second.
  static thread_local std::default_random_engine generator(
      std::random_device{}());
  // bernoulli_distribution requires its probability to lie in [0, 1].
  std::bernoulli_distribution distribution(
      std::min(std::max(static_cast<double>(threshold), 0.0), 1.0));
  for (int i = 0; i < nnum_dropout; ++i) {
    drop[depth_dropout][i] = static_cast<uintk>(distribution(generator));
  }
  // A child inherits its parent's flag, so dropping a node drops its subtree.
  for (int d = depth_dropout + 1; d <= depth; ++d) {
    const int nnum_dp = parser_in.info().node_num(d - 1);
    const int* children_dp = parser_in.children_cpu(d - 1);
    drop[d].resize(parser_in.info().node_num(d), 0);
    for (int i = 0; i < nnum_dp; ++i) {
      const int t = children_dp[i];
      if (t < 0) continue;  // node i has no children
      // the 8 children of a node are stored contiguously at t * 8 .. t * 8 + 7
      std::fill_n(drop[d].begin() + t * 8, 8, drop[d - 1][i]);
    }
  }

  // ---- 2. output header ---------------------------------------------------
  OctreeInfo info_output = parser_in.info();
  vector<int> node_num(depth + 1, 0);
  for (int d = 0; d <= depth; ++d) {
    node_num[d] = d <= depth_dropout
                      ? parser_in.info().node_num(d)
                      : static_cast<int>(std::count(drop[d].begin(),
                                                    drop[d].end(),
                                                    static_cast<uintk>(0)));
  }
  info_output.set_nnum(node_num.data());
  info_output.set_nnum_cum();
  info_output.set_ptr_dis();
  octree_output.resize(info_output.sizeof_octree());
  OctreeParser parser_out;
  parser_out.set_cpu(octree_output.data(), &info_output);

  // ---- 3. levels 0..depth_dropout: copied verbatim ------------------------
  const int num = parser_in.info().node_num_cum(depth_dropout + 1);
  const int channel_key = parser_in.info().channel(OctreeInfo::kKey);
  // CHECK_EQ(channel_key, 1) << "Currently the channel must be 1";
  std::copy_n(parser_in.key_cpu(0), num * channel_key,
              parser_out.mutable_key_cpu(0));
  std::copy_n(parser_in.children_cpu(0), num,
              parser_out.mutable_children_cpu(0));
  // Properties stored on every level (location == -1) are copied here; those
  // stored only on the finest level are handled in step 4.
  const int channel_feature = parser_in.info().channel(OctreeInfo::kFeature);
  const int location_feature = parser_in.info().locations(OctreeInfo::kFeature);
  if (location_feature == -1) {
    std::copy_n(parser_in.feature_cpu(0), num * channel_feature,
                parser_out.mutable_feature_cpu(0));
  }
  const int channel_label = parser_in.info().channel(OctreeInfo::kLabel);
  const int location_label = parser_in.info().locations(OctreeInfo::kLabel);
  if (location_label == -1) {
    std::copy_n(parser_in.label_cpu(0), num * channel_label,
                parser_out.mutable_label_cpu(0));
  }
  const int channel_split = parser_in.info().channel(OctreeInfo::kSplit);
  const int location_split = parser_in.info().locations(OctreeInfo::kSplit);
  if (location_split == -1) {
    std::copy_n(parser_in.split_cpu(0), num * channel_split,
                parser_out.mutable_split_cpu(0));
  }

  // ---- 4. levels depth_dropout+1..depth: keep undropped nodes only --------
  vector<int> node_num_nempty(depth + 1, 0);
  for (int d = depth_dropout + 1; d <= depth; ++d) {
    const int nnum_src = parser_in.info().node_num(d);
    const int nnum_des = parser_out.info().node_num(d);
    const vector<uintk>& flag = drop[d];

    // Kept nodes that have children are renumbered consecutively. This lines
    // up with the compacted next level because a node's children are kept
    // exactly when the node itself is kept.
    // NOTE(review): keys are copied one value per node here, unlike the
    // num * channel_key copy in step 3 — assumes channel(kKey) == 1.
    const int* child_src = parser_in.children_cpu(d);
    const uintk* key_src = parser_in.key_cpu(d);
    int* child_des = parser_out.mutable_children_cpu(d);
    uintk* key_des = parser_out.mutable_key_cpu(d);
    int id = 0;
    for (int i = 0, j = 0; i < nnum_src; ++i) {
      if (flag[i] != 0) continue;
      key_des[j] = key_src[i];
      child_des[j] = child_src[i] < 0 ? child_src[i] : id++;
      ++j;
    }
    node_num_nempty[d] = id;

    if (location_feature == -1 || d == depth) {
      octree_dropout_copy_kept(parser_in.feature_cpu(d), nnum_src,
                               parser_out.mutable_feature_cpu(d), nnum_des,
                               channel_feature, flag);
    }
    if ((location_label == -1 || d == depth) && channel_label != 0) {
      octree_dropout_copy_kept(parser_in.label_cpu(d), nnum_src,
                               parser_out.mutable_label_cpu(d), nnum_des,
                               channel_label, flag);
    }
    if ((location_split == -1 || d == depth) && channel_split != 0) {
      octree_dropout_copy_kept(parser_in.split_cpu(d), nnum_src,
                               parser_out.mutable_split_cpu(d), nnum_des,
                               channel_split, flag);
    }
  }

  // ---- 5. re-index children at depth_dropout ------------------------------
  // A dropped node has lost its children, so it must be marked childless
  // (-1). Keeping its old non-negative index would point into the compacted
  // next level, past its end or onto a kept sibling's renumbered children.
  int id = 0;
  const int* child_src = parser_in.children_cpu(depth_dropout);
  int* child_des = parser_out.mutable_children_cpu(depth_dropout);
  for (int i = 0; i < node_num[depth_dropout]; ++i) {
    if (drop[depth_dropout][i] != 0) {
      child_des[i] = -1;
    } else {
      child_des[i] = child_src[i] < 0 ? child_src[i] : id++;
    }
  }
  for (int d = 0; d < depth_dropout; ++d) {
    node_num_nempty[d] = parser_in.info().node_num_nempty(d);
  }
  node_num_nempty[depth_dropout] = id;  // !!! important: dropped nodes are now empty
  parser_out.mutable_info().set_nempty(node_num_nempty.data());
}