in src/operator/contrib/multibox_target.cc [72:280]
inline void MultiBoxTargetForward(const Tensor<cpu, 2, DType> &loc_target,
const Tensor<cpu, 2, DType> &loc_mask,
const Tensor<cpu, 2, DType> &cls_target,
const Tensor<cpu, 2, DType> &anchors,
const Tensor<cpu, 3, DType> &labels,
const Tensor<cpu, 3, DType> &cls_preds,
const Tensor<cpu, 4, DType> &temp_space,
const float overlap_threshold,
const float background_label,
const float negative_mining_ratio,
const float negative_mining_thresh,
const int minimum_negative_samples,
const nnvm::Tuple<float> &variances) {
const DType *p_anchor = anchors.dptr_;
const int num_batches = labels.size(0);
const int num_labels = labels.size(1);
const int label_width = labels.size(2);
const int num_anchors = anchors.size(0);
CHECK_EQ(variances.ndim(), 4);
for (int nbatch = 0; nbatch < num_batches; ++nbatch) {
const DType *p_label = labels.dptr_ + nbatch * num_labels * label_width;
const DType *p_overlaps = temp_space.dptr_ + nbatch * num_anchors * num_labels;
int num_valid_gt = 0;
for (int i = 0; i < num_labels; ++i) {
if (static_cast<float>(*(p_label + i * label_width)) == -1.0f) {
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 1)), -1.0f);
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 2)), -1.0f);
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 3)), -1.0f);
CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 4)), -1.0f);
break;
}
++num_valid_gt;
} // end iterate labels
if (num_valid_gt > 0) {
std::vector<bool> gt_flags(num_valid_gt, false);
std::vector<std::pair<float, int>> max_matches(num_anchors,
std::pair<float, int>(-1.0f, -1));
std::vector<char> anchor_flags(num_anchors, -1); // -1 means don't care
int num_positive = 0;
while (std::find(gt_flags.begin(), gt_flags.end(), false) != gt_flags.end()) {
// ground-truths not fully matched
int best_anchor = -1;
int best_gt = -1;
float max_overlap = 1e-6; // start with a very small positive overlap
for (int j = 0; j < num_anchors; ++j) {
if (anchor_flags[j] == 1) {
continue; // already matched this anchor
}
const DType *pp_overlaps = p_overlaps + j * num_labels;
for (int k = 0; k < num_valid_gt; ++k) {
if (gt_flags[k]) {
continue; // already matched this gt
}
float iou = static_cast<float>(*(pp_overlaps + k));
if (iou > max_overlap) {
best_anchor = j;
best_gt = k;
max_overlap = iou;
}
}
}
if (best_anchor == -1) {
CHECK_EQ(best_gt, -1);
break; // no more good match
} else {
CHECK_EQ(max_matches[best_anchor].first, -1.0f);
CHECK_EQ(max_matches[best_anchor].second, -1);
max_matches[best_anchor].first = max_overlap;
max_matches[best_anchor].second = best_gt;
num_positive += 1;
// mark as visited
gt_flags[best_gt] = true;
anchor_flags[best_anchor] = 1;
}
} // end while
if (overlap_threshold > 0) {
// find positive matches based on overlaps
for (int j = 0; j < num_anchors; ++j) {
if (anchor_flags[j] == 1) {
continue; // already matched this anchor
}
const DType *pp_overlaps = p_overlaps + j * num_labels;
int best_gt = -1;
float max_iou = -1.0f;
for (int k = 0; k < num_valid_gt; ++k) {
float iou = static_cast<float>(*(pp_overlaps + k));
if (iou > max_iou) {
best_gt = k;
max_iou = iou;
}
}
if (best_gt != -1) {
CHECK_EQ(max_matches[j].first, -1.0f);
CHECK_EQ(max_matches[j].second, -1);
max_matches[j].first = max_iou;
max_matches[j].second = best_gt;
if (max_iou > overlap_threshold) {
num_positive += 1;
// mark as visited
gt_flags[best_gt] = true;
anchor_flags[j] = 1;
}
}
} // end iterate anchors
}
if (negative_mining_ratio > 0) {
const int num_classes = cls_preds.size(1);
DType *p_cls_preds = cls_preds.dptr_ + nbatch * num_classes * num_anchors;
CHECK_GT(negative_mining_thresh, 0);
int num_negative = num_positive * negative_mining_ratio;
if (num_negative > (num_anchors - num_positive)) {
num_negative = num_anchors - num_positive;
}
if (num_negative > 0) {
// use negative mining, pick up "best" negative samples
std::vector<SortElemDescend> temp;
temp.reserve(num_anchors - num_positive);
for (int j = 0; j < num_anchors; ++j) {
if (anchor_flags[j] == 1) {
continue; // already matched this anchor
}
if (max_matches[j].first < 0) {
// not yet calculated
const DType *pp_overlaps = p_overlaps + j * num_labels;
int best_gt = -1;
float max_iou = -1.0f;
for (int k = 0; k < num_valid_gt; ++k) {
float iou = static_cast<float>(*(pp_overlaps + k));
if (iou > max_iou) {
best_gt = k;
max_iou = iou;
}
}
if (best_gt != -1) {
CHECK_EQ(max_matches[j].first, -1.0f);
CHECK_EQ(max_matches[j].second, -1);
max_matches[j].first = max_iou;
max_matches[j].second = best_gt;
}
}
if (max_matches[j].first < negative_mining_thresh &&
anchor_flags[j] == -1) {
// calcuate class predictions
DType max_val = p_cls_preds[j];
for (int k = 1; k < num_classes; ++k) {
DType tmp = p_cls_preds[j + num_anchors * k];
if (tmp > max_val) max_val = tmp;
}
DType sum = 0.f;
for (int k = 0; k < num_classes; ++k) {
DType tmp = p_cls_preds[j + num_anchors * k];
sum += std::exp(tmp - max_val);
}
DType prob = std::exp(p_cls_preds[j] - max_val) / sum;
// loss should be -log(x), but value does not matter, skip log
temp.push_back(SortElemDescend(-prob, j));
}
} // end iterate anchors
CHECK_GE(temp.size(), num_negative);
std::stable_sort(temp.begin(), temp.end());
for (int i = 0; i < num_negative; ++i) {
anchor_flags[temp[i].index] = 0; // mark as negative sample
}
}
} else {
// use all negative samples
for (int i = 0; i < num_anchors; ++i) {
if (anchor_flags[i] != 1) {
anchor_flags[i] = 0;
}
}
}
// assign training targets
DType *p_loc_target = loc_target.dptr_ + nbatch * num_anchors * 4;
DType *p_loc_mask = loc_mask.dptr_ + nbatch * num_anchors * 4;
DType *p_cls_target = cls_target.dptr_ + nbatch * num_anchors;
for (int i = 0; i < num_anchors; ++i) {
if (anchor_flags[i] == 1) {
// positive sample
CHECK_GE(max_matches[i].second, 0);
// 0 reserved for background
*(p_cls_target + i) = *(p_label + label_width * max_matches[i].second) + 1;
int offset = i * 4;
*(p_loc_mask + offset) = 1;
*(p_loc_mask + offset + 1) = 1;
*(p_loc_mask + offset + 2) = 1;
*(p_loc_mask + offset + 3) = 1;
AssignLocTargets(p_anchor + i * 4,
p_label + label_width * max_matches[i].second + 1, p_loc_target + offset,
variances[0], variances[1], variances[2], variances[3]);
} else if (anchor_flags[i] == 0) {
// negative sample
*(p_cls_target + i) = 0;
int offset = i * 4;
*(p_loc_mask + offset) = 0;
*(p_loc_mask + offset + 1) = 0;
*(p_loc_mask + offset + 2) = 0;
*(p_loc_mask + offset + 3) = 0;
}
} // end iterate anchors
}
} // end iterate batches
}