void inputLayerRules()

in sparseconvnet/SCN/Metadata/IOLayersRules.h [19:121]


// Populate the per-sample sparse grids from a coordinate array and build the
// input-layer rule book.
//
// coords: row-major array of nInputRows rows. Each row holds `dimension`
//   spatial coordinates. If nInputColumns != dimension, each row is assumed to
//   have one extra trailing column holding the batch-sample index (intended:
//   nInputColumns == dimension + 1). If nInputColumns == dimension, every row
//   belongs to sample 0 and SGs is resized to exactly one grid.
//
// mode == 0: rows are taken as-is. Grid entry for each row's point is set to
//   the row number (a later row with the same point overwrites an earlier
//   one), nActive == nInputRows, and rules has only the header rules[0].
// mode != 0: rows with the same point in the same sample are merged into one
//   active site. Sites are numbered 0..nActive-1 in order of first appearance,
//   and rules[1] lists, per site:
//     mode 1   -> {1, first input row}
//     mode 2   -> {1, last input row}
//     mode 3/4 -> {count, rows..., zero padding}, each site padded to exactly
//                 maxActive + 1 entries
//   (for any other nonzero mode, rules[1] is left empty).
//
// rules[0] is the header {mode, maxActive (1 unless mode 3/4), nInputRows,
// number of active output rows}.
//
// Preconditions (asserted): nActive == 0, rules and SGs empty, and every batch
// index is non-negative. SGs is first sized to batchSize and grown to cover
// every batch index encountered.
void inputLayerRules(SparseGrids<dimension> &SGs, RuleBook &rules, long *coords,
                     Int nInputRows, Int nInputColumns, Int batchSize, Int mode,
                     Int &nActive) {
  assert(nActive == 0);
  assert(rules.size() == 0);
  assert(SGs.size() == 0);
  SGs.resize(batchSize); // Set a minimum batch size if necessary

  // Without a batch column all rows share sample 0's grid.
  const bool hasBatchColumn = (nInputColumns != dimension);
  const Int stride = hasBatchColumn ? dimension + 1 : dimension;
  if (!hasBatchColumn)
    SGs.resize(1);

  if (mode == 0)
    nActive = nInputRows;

  // outputRows[k] collects the input rows merged into active site k
  // (only used when mode != 0).
  std::vector<std::vector<Int>> outputRows;
  Point<dimension> p;
  for (Int i = 0; i < nInputRows; ++i, coords += stride) {
    for (Int j = 0; j < dimension; ++j)
      p[j] = coords[j];

    const Int idx = hasBatchColumn ? static_cast<Int>(coords[dimension]) : 0;
    assert(idx >= 0); // a negative sample index would index SGs out of bounds
    if (idx >= static_cast<Int>(SGs.size()))
      SGs.resize(idx + 1);
    auto &mp = SGs[idx].mp;

    if (mode == 0) {
      mp[p] = i;
    } else {
      // A new point is assigned the next site number. insert() returns an
      // iterator to the stored site number whether or not the point was new,
      // so no second hash lookup is needed.
      auto inserted = mp.insert(make_pair(p, nActive));
      if (inserted.second)
        outputRows.resize(++nActive);
      outputRows[inserted.first->second].push_back(i);
    }
  }

  if (mode == 0) {
    rules.resize(1);
    rules[0].push_back(mode);
    rules[0].push_back(1);
    rules[0].push_back(nInputRows);
    rules[0].push_back(nInputRows);
    return;
  }

  rules.resize(2);
  rules[0].push_back(mode);
  rules[0].push_back(1); // replaced with maxActive below if mode == 3 or 4
  rules[0].push_back(nInputRows);
  rules[0].push_back(outputRows.size());
  auto &rule = rules[1];

  if (mode == 1 or mode == 2) {
    // One {1, chosen input row} pair per active site.
    for (auto &row : outputRows) {
      rule.push_back(1);
      rule.push_back(mode == 1 ? row.front() : row.back());
    }
  } else if (mode == 3 or mode == 4) {
    Int maxActive = 0;
    for (auto &row : outputRows)
      maxActive = std::max(maxActive, static_cast<Int>(row.size()));
    rules[0][1] = maxActive;
    // Each site: its row count, its rows, then zero padding so that every
    // site occupies exactly maxActive + 1 entries.
    for (auto &row : outputRows) {
      rule.push_back(row.size());
      rule.insert(rule.end(), row.begin(), row.end());
      rule.resize(rule.size() + (maxActive - static_cast<Int>(row.size())));
    }
  }
}