inline Tensor einsum() — excerpt from include/tvm/topi/einsum.h (original lines 681–939).

/*!
 * \brief Evaluate an Einstein-summation expression over the given tensors.
 *
 * Port of NumPy's C einsum front-end: parses the subscripts string,
 * resolves diagonal/combined and broadcast dimensions, builds an
 * iteration space (output labels followed by reduction labels), and
 * returns a Tensor whose each output element is the sum of products
 * over the reduction space.
 *
 * \param subscripts_str Einsum subscripts, e.g. "ij,jk->ik". The explicit
 *        "->" output part is optional; without it the output labels are
 *        every label that appears exactly once, in alphabetical order.
 * \param inputs One operand Tensor per comma-separated input subscript.
 * \param name Name of the resulting compute operation.
 * \param tag Tag of the resulting compute operation.
 * \return A Tensor of the einsum result, with shape given by
 *         NumpyEinsumShape(subscripts_str, input shapes).
 */
inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
                     std::string name = "T_einsum", std::string tag = kEinsum) {
  // NOTE(review): `back` is always false here, so `iop + back` == `iop`
  // everywhere below; it appears to be a vestige of the NumPy C source this
  // was ported from — confirm before removing.
  bool back = false;
  const char* subscripts = subscripts_str.data();
  const char* head = subscripts;  // start of the string, kept for bounds checks
  const int nop = inputs.size();

  /* Step 1: Parse the subscripts string into label_counts and op_labels */
  int iop, idim, min_label = LABELRANGE - 1, max_label = 0;
  // label_counts[c]: number of occurrences of label c across all inputs.
  // op_labels[iop][dim]: label for each dimension of operand iop; 0 means a
  // broadcast ('...') dimension, negative values mark repeated (combined)
  // labels (see Step 2).
  char label_counts[LABELRANGE], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
  memset(label_counts, 0, sizeof(label_counts));
  for (iop = 0; iop < nop; ++iop) {
    // Each operand's subscript ends at ',' (next operand) or '-' (start of "->").
    int length = static_cast<int>(strcspn(subscripts, ",-"));

    CHECK(!(iop == nop - 1 && subscripts[length] == ','))
        << "more operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK(!(iop < nop - 1 && subscripts[length] != ','))
        << "fewer operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK_EQ(ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop,
                                    op_labels[iop], label_counts, &min_label, &max_label),
             0);

    /* Move subscripts to the start of the labels for the next op */
    subscripts += length;

    if (iop < nop - 1) {
      CHECK_LT(subscripts - head, subscripts_str.length()) << "subscripts out of range";
      subscripts++;  // skip the ',' separator
    }
  }
  /*
   * Find the number of broadcast dimensions, which is the maximum
   * number of labels == 0 in an op_labels array.
   */
  int ndim_broadcast = 0;
  for (iop = 0; iop < nop; ++iop) {
    int count_zeros = 0;
    int ndim;
    char* labels = op_labels[iop];

    ndim = inputs[iop + back].ndim();
    for (idim = 0; idim < ndim; ++idim) {
      if (labels[idim] == 0) {
        ++count_zeros;
      }
    }

    if (count_zeros > ndim_broadcast) {
      ndim_broadcast = count_zeros;
    }
  }

  /*
   * If there is no output signature, fill output_labels and ndim_output
   * using each label that appeared once, in alphabetical order.
   */
  int label, ndim_output;
  char output_labels[NPY_MAXDIMS];
  if (subscripts[0] == '\0') {
    /* If no output was specified, always broadcast left, as usual. */
    for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
      output_labels[ndim_output] = 0;
    }
    for (label = min_label; label <= max_label; ++label) {
      if (label_counts[label] == 1) {
        CHECK(ndim_output < NPY_MAXDIMS) << "einstein sum subscript string has too many "
                                         << "distinct labels";
        output_labels[ndim_output++] = label;
      }
    }
  } else {
    // Explicit output: subscripts now points at the "->" left behind by Step 1.
    CHECK(subscripts[0] == '-' && subscripts[1] == '>') << "einstein sum subscript string does not "
                                                        << "contain proper '->' output specified";
    subscripts += 2;

    /* Parse the output subscript string. */
    ndim_output = ParseOutputSubscripts(subscripts, strlen(subscripts), ndim_broadcast,
                                        label_counts, output_labels);
    CHECK_GE(ndim_output, 0);
  }

  /*
   * Step 2:
   * Process all the input ops, combining dimensions into their
   * diagonal where specified.
   */
  // opshape/opstride_true: per-operand shape and (flat, row-major) strides
  // after collapsing repeated labels (e.g. "ii" -> diagonal view).
  std::vector<Array<PrimExpr>> opshape(nop), opstride_true(nop);
  for (iop = 0; iop < nop; ++iop) {
    char* labels = op_labels[iop];
    int combine, ndim;

    ndim = inputs[iop + back].ndim();

    /*
     * Check whether any dimensions need to be combined
     *
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    combine = 0;
    for (idim = 0; idim < ndim; ++idim) {
      // Negative label = repeated label (offset back to its first occurrence).
      if ((signed char)labels[idim] < 0) {
        combine++;
      }
    }
    /* If any dimensions are combined, create a view which combines them */
    if (combine) {
      // -1 is a placeholder; GetCombinedDimsView fills in the real values.
      Array<PrimExpr> tshape(static_cast<size_t>(ndim - combine), -1);
      Array<PrimExpr> tstride(static_cast<size_t>(ndim - combine), -1);
      GetCombinedDimsView(inputs[iop + back], iop, labels, &tshape, &tstride);
      opshape[iop] = tshape;
      opstride_true[iop] = tstride;
    } else {
      /* No combining needed */
      opshape[iop] = inputs[iop + back]->shape;
      opstride_true[iop] = GetStride(opshape[iop]);
    }
  }
  /*
   * Step 3:
   * Set up the labels for the iterator (output + combined labels).
   * Can just share the output_labels memory, because iter_labels
   * is output_labels with some more labels appended.
   */
  // iter_labels ALIASES output_labels: indices [0, ndim_output) are the output
  // labels, and the reduction labels are appended after them.
  char* iter_labels = output_labels;
  int ndim_iter = ndim_output;
  for (label = min_label; label <= max_label; ++label) {
    // Any used label not already in the output becomes a reduction axis.
    if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) == nullptr) {
      CHECK(ndim_iter < NPY_MAXDIMS) << "too many subscripts in einsum";
      iter_labels[ndim_iter++] = label;
    }
  }
  /* Step 4: Set up the op_axes for the iterator */
  // itershape starts with the sentinel -1 in each slot; filled from operand
  // extents below. iterstride[iop] is operand iop's stride along each
  // iterator axis (0 = that operand is broadcast along the axis);
  // iterstride[nop] is the output's strides.
  Array<PrimExpr> itershape(static_cast<size_t>(ndim_iter), -1);
  std::vector<Array<PrimExpr>> iterstride(nop + 1,
                                          Array<PrimExpr>(static_cast<size_t>(ndim_iter), 0));

  // output_shape
  std::vector<Array<PrimExpr>> operands;
  for (size_t i = 0; i < inputs.size(); i++) {
    operands.push_back(inputs[i]->shape);
  }
  Array<PrimExpr> oshape = NumpyEinsumShape(subscripts_str, operands);
  Array<PrimExpr> ostride_true = GetStride(oshape);
  Array<PrimExpr> reduceshape;
  std::vector<Array<PrimExpr>> remainshape(nop);
  // op_axes[iop][idim]: which dimension of operand iop maps to iterator axis
  // idim, or -1 if the operand has no such dimension (broadcast).
  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
  int* op_axes[NPY_MAXARGS];
  for (iop = 0; iop < nop; ++iop) {
    op_axes[iop] = op_axes_arrays[iop];
    CHECK_GE(PrepareOpAxes(opshape[iop].size(), iop, op_labels[iop], op_axes[iop], ndim_iter,
                           iter_labels),
             0);
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes[iop][idim] != -1) {
        iterstride[iop].Set(idim, opstride_true[iop][op_axes[iop][idim]]);
        // Fill the axis extent the first time (-1 sentinel), and let a larger
        // extent replace a previously-recorded 1 (broadcast of size-1 dims).
        if (GetConstInt(itershape[idim]) != -1) {
          if (GetConstInt(itershape[idim]) == 1) {
            itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
          }
        } else {
          itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
        }
      }
    }
  }
  // The output iterates only the first ndim_output axes.
  for (idim = 0; idim < ndim_output; ++idim) {
    iterstride[nop].Set(idim, ostride_true[idim]);
  }
  // reduceshape: extents of the reduction axes (iterator axes past the output).
  reduceshape = Array<PrimExpr>(static_cast<size_t>(ndim_iter - ndim_output), 0);
  for (idim = ndim_output; idim < ndim_iter; ++idim) {
    reduceshape.Set(idim - ndim_output, itershape[idim]);
  }
  for (iop = 0; iop < nop; iop++) {
    Array<Integer> rsh;
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes_arrays[iop][idim] == -1) {
        rsh.push_back(GetConstInt(itershape[idim]));
      } else {
        // NOTE(review): the `!=` sits INSIDE GetConstInt(...), so this builds
        // a comparison PrimExpr and extracts its constant value (nonzero when
        // the extents differ, i.e. a size-1 broadcast axis). Verify this is
        // intended rather than GetConstInt(a) != GetConstInt(b).
        if (GetConstInt(itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]])) {
          rsh.push_back(GetConstInt(itershape[idim]));
        }
      }
    }
    remainshape[iop] = Array<PrimExpr>(rsh.begin(), rsh.end());
  }
  // exclude the 0-dim case
  if (ndim_iter == 0) {
    ndim_iter = 1;
  }
  // Pad every per-axis array to the common length ndim_iter.
  itershape = Pad(itershape, ndim_iter);
  for (iop = 0; iop <= nop; ++iop) {
    iterstride[iop] = Pad(iterstride[iop], ndim_iter);
  }
  // oshape = Pad(oshape, ndim_iter);
  reduceshape = Pad(reduceshape, ndim_iter);
  for (iop = 0; iop < nop; ++iop) {
    opshape[iop] = Pad(opshape[iop], ndim_iter);
    remainshape[iop] = Pad(remainshape[iop], ndim_iter);
  }
  // ostride and rstride
  // ostride[iop]: operand iop's strides over the OUTPUT axes;
  // rstride[iop]: operand iop's strides over the REDUCTION axes
  // (iterator axes shifted down by ndim_output). Unused slots are 1.
  Array<Array<PrimExpr>> ostride;
  Array<Array<PrimExpr>> rstride;

  for (iop = 0; iop < nop; ++iop) {
    Array<PrimExpr> otmp(static_cast<size_t>(ndim_iter), 0);
    Array<PrimExpr> rtmp(static_cast<size_t>(ndim_iter), 0);
    for (idim = 0; idim < ndim_iter; ++idim) {
      otmp.Set(idim, idim < ndim_output ? iterstride[iop][idim] : 1);
      rtmp.Set(idim, idim < ndim_iter - ndim_output ? iterstride[iop][idim + ndim_output] : 1);
    }
    ostride.push_back(otmp);
    rstride.push_back(rtmp);
  }

  // func: input indices => return cooresponding value
  // For each output index, enumerates the whole reduction space at expression
  // construction time (the reduce extents are compile-time constants) and
  // accumulates the product of the operands' elements.
  auto func = [inputs, oshape, ostride, reduceshape, ndim_iter, rstride,
               nop](const Array<Var>& input_indices) -> PrimExpr {
    // An empty reduction axis means the sum over it is 0.
    for (int rdim = 0; rdim < ndim_iter; ++rdim) {
      if (GetConstInt(reduceshape[rdim]) == 0) {
        return 0;  //
      }
    }
    // ridx: multi-index into the reduction space, starting at all-zeros.
    Array<PrimExpr> ridx = UnravelIndex(0, reduceshape);

    PrimExpr sum = 0;
    bool rec_flag = false;
    do {
      PrimExpr tmp = 1;
      for (int iop = 0; iop < nop; ++iop) {
        // NOTE(review): `iop != -1` is always true (loop starts at 0);
        // likely a leftover from the NumPy port — confirm before removing.
        if (iop != -1) {
          // Flat offset into operand iop: output part + reduction part.
          PrimExpr k = 0;

          for (size_t i = 0; i < input_indices.size(); ++i) {
            k += input_indices[i] * ostride[iop][i];
          }
          for (size_t i = 0; i < ridx.size(); ++i) {
            k += ridx[i] * rstride[iop][i];
          }
          Array<PrimExpr> temp_indices = UnravelIndex(k, inputs[iop]->shape);
          tmp = tmp * inputs[iop](temp_indices);
        }
      }
      sum += tmp;
      // Odometer-style increment of ridx with carry toward axis 0.
      ridx.Set(ridx.size() - 1, ridx[ridx.size() - 1] + 1);
      for (int i = static_cast<int>(ridx.size() - 1);
           (i > 0) && GetConstInt(ridx[i] >= reduceshape[i]); --i) {
        ridx.Set(i, ridx[i] - reduceshape[i]);
        ridx.Set(i - 1, ridx[i - 1] + 1);
      }
      // Stop once the most-significant reduction index overflows.
      rec_flag = GetConstInt(ridx[0] < reduceshape[0]);
    } while (rec_flag);
    return sum;
  };

  return compute(oshape, func, name, tag);
}