inline void MultiBoxTargetForward()

in src/operator/contrib/multibox_target.cc [72:280]


inline void MultiBoxTargetForward(const Tensor<cpu, 2, DType> &loc_target,
                           const Tensor<cpu, 2, DType> &loc_mask,
                           const Tensor<cpu, 2, DType> &cls_target,
                           const Tensor<cpu, 2, DType> &anchors,
                           const Tensor<cpu, 3, DType> &labels,
                           const Tensor<cpu, 3, DType> &cls_preds,
                           const Tensor<cpu, 4, DType> &temp_space,
                           const float overlap_threshold,
                           const float background_label,
                           const float negative_mining_ratio,
                           const float negative_mining_thresh,
                           const int minimum_negative_samples,
                           const nnvm::Tuple<float> &variances) {
  const DType *p_anchor = anchors.dptr_;
  const int num_batches = labels.size(0);
  const int num_labels = labels.size(1);
  const int label_width = labels.size(2);
  const int num_anchors = anchors.size(0);
  CHECK_EQ(variances.ndim(), 4);
  for (int nbatch = 0; nbatch < num_batches; ++nbatch) {
    const DType *p_label = labels.dptr_ + nbatch * num_labels * label_width;
    const DType *p_overlaps = temp_space.dptr_ + nbatch * num_anchors * num_labels;
    int num_valid_gt = 0;
    for (int i = 0; i < num_labels; ++i) {
      if (static_cast<float>(*(p_label + i * label_width)) == -1.0f) {
        CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 1)), -1.0f);
        CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 2)), -1.0f);
        CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 3)), -1.0f);
        CHECK_EQ(static_cast<float>(*(p_label + i * label_width + 4)), -1.0f);
        break;
      }
      ++num_valid_gt;
    }  // end iterate labels

    if (num_valid_gt > 0) {
      std::vector<bool> gt_flags(num_valid_gt, false);
      std::vector<std::pair<float, int>> max_matches(num_anchors,
        std::pair<float, int>(-1.0f, -1));
      std::vector<char> anchor_flags(num_anchors, -1);  // -1 means don't care
      int num_positive = 0;
      while (std::find(gt_flags.begin(), gt_flags.end(), false) != gt_flags.end()) {
        // ground-truths not fully matched
        int best_anchor = -1;
        int best_gt = -1;
        float max_overlap = 1e-6;  // start with a very small positive overlap
        for (int j = 0; j < num_anchors; ++j) {
          if (anchor_flags[j] == 1) {
            continue;  // already matched this anchor
          }
          const DType *pp_overlaps = p_overlaps + j * num_labels;
          for (int k = 0; k < num_valid_gt; ++k) {
            if (gt_flags[k]) {
              continue;  // already matched this gt
            }
            float iou = static_cast<float>(*(pp_overlaps + k));
            if (iou > max_overlap) {
              best_anchor = j;
              best_gt = k;
              max_overlap = iou;
            }
          }
        }

        if (best_anchor == -1) {
          CHECK_EQ(best_gt, -1);
          break;  // no more good match
        } else {
          CHECK_EQ(max_matches[best_anchor].first, -1.0f);
          CHECK_EQ(max_matches[best_anchor].second, -1);
          max_matches[best_anchor].first = max_overlap;
          max_matches[best_anchor].second = best_gt;
          num_positive += 1;
          // mark as visited
          gt_flags[best_gt] = true;
          anchor_flags[best_anchor] = 1;
        }
      }  // end while

      if (overlap_threshold > 0) {
        // find positive matches based on overlaps
        for (int j = 0; j < num_anchors; ++j) {
          if (anchor_flags[j] == 1) {
            continue;  // already matched this anchor
          }
          const DType *pp_overlaps = p_overlaps + j * num_labels;
          int best_gt = -1;
          float max_iou = -1.0f;
          for (int k = 0; k < num_valid_gt; ++k) {
            float iou = static_cast<float>(*(pp_overlaps + k));
            if (iou > max_iou) {
              best_gt = k;
              max_iou = iou;
            }
          }
          if (best_gt != -1) {
            CHECK_EQ(max_matches[j].first, -1.0f);
            CHECK_EQ(max_matches[j].second, -1);
            max_matches[j].first = max_iou;
            max_matches[j].second = best_gt;
            if (max_iou > overlap_threshold) {
              num_positive += 1;
              // mark as visited
              gt_flags[best_gt] = true;
              anchor_flags[j] = 1;
            }
          }
        }  // end iterate anchors
      }

      if (negative_mining_ratio > 0) {
        const int num_classes = cls_preds.size(1);
        DType *p_cls_preds = cls_preds.dptr_ + nbatch * num_classes * num_anchors;
        CHECK_GT(negative_mining_thresh, 0);
        int num_negative = num_positive * negative_mining_ratio;
        if (num_negative > (num_anchors - num_positive)) {
          num_negative = num_anchors - num_positive;
        }
        if (num_negative > 0) {
          // use negative mining, pick up "best" negative samples
          std::vector<SortElemDescend> temp;
          temp.reserve(num_anchors - num_positive);
          for (int j = 0; j < num_anchors; ++j) {
            if (anchor_flags[j] == 1) {
              continue;  // already matched this anchor
            }
            if (max_matches[j].first < 0) {
              // not yet calculated
              const DType *pp_overlaps = p_overlaps + j * num_labels;
              int best_gt = -1;
              float max_iou = -1.0f;
              for (int k = 0; k < num_valid_gt; ++k) {
                float iou = static_cast<float>(*(pp_overlaps + k));
                if (iou > max_iou) {
                  best_gt = k;
                  max_iou = iou;
                }
              }
              if (best_gt != -1) {
                CHECK_EQ(max_matches[j].first, -1.0f);
                CHECK_EQ(max_matches[j].second, -1);
                max_matches[j].first = max_iou;
                max_matches[j].second = best_gt;
              }
            }
            if (max_matches[j].first < negative_mining_thresh &&
                anchor_flags[j] == -1) {
                // calcuate class predictions
              DType max_val = p_cls_preds[j];
              for (int k = 1; k < num_classes; ++k) {
                DType tmp = p_cls_preds[j + num_anchors * k];
                if (tmp > max_val) max_val = tmp;
              }
              DType sum = 0.f;
              for (int k = 0; k < num_classes; ++k) {
                DType tmp = p_cls_preds[j + num_anchors * k];
                sum += std::exp(tmp - max_val);
              }
              DType prob = std::exp(p_cls_preds[j] - max_val) / sum;
              // loss should be -log(x), but value does not matter, skip log
              temp.push_back(SortElemDescend(-prob, j));
            }
          }  // end iterate anchors

          CHECK_GE(temp.size(), num_negative);
          std::stable_sort(temp.begin(), temp.end());
          for (int i = 0; i < num_negative; ++i) {
            anchor_flags[temp[i].index] = 0;  // mark as negative sample
          }
        }
      } else {
        // use all negative samples
        for (int i = 0; i < num_anchors; ++i) {
          if (anchor_flags[i] != 1) {
            anchor_flags[i] = 0;
          }
        }
      }

      // assign training targets
      DType *p_loc_target = loc_target.dptr_ + nbatch * num_anchors * 4;
      DType *p_loc_mask = loc_mask.dptr_ + nbatch * num_anchors * 4;
      DType *p_cls_target = cls_target.dptr_ + nbatch * num_anchors;
      for (int i = 0; i < num_anchors; ++i) {
        if (anchor_flags[i] == 1) {
          // positive sample
          CHECK_GE(max_matches[i].second, 0);
          // 0 reserved for background
          *(p_cls_target + i) = *(p_label + label_width * max_matches[i].second) + 1;
          int offset = i * 4;
          *(p_loc_mask + offset) = 1;
          *(p_loc_mask + offset + 1) = 1;
          *(p_loc_mask + offset + 2) = 1;
          *(p_loc_mask + offset + 3) = 1;
          AssignLocTargets(p_anchor + i * 4,
            p_label + label_width * max_matches[i].second + 1, p_loc_target + offset,
            variances[0], variances[1], variances[2], variances[3]);
        } else if (anchor_flags[i] == 0) {
          // negative sample
          *(p_cls_target + i) = 0;
          int offset = i * 4;
          *(p_loc_mask + offset) = 0;
          *(p_loc_mask + offset + 1) = 0;
          *(p_loc_mask + offset + 2) = 0;
          *(p_loc_mask + offset + 3) = 0;
        }
      }  // end iterate anchors
    }
  }  // end iterate batches
}