in include/tvm/topi/einsum.h [681:939]
/*!
 * \brief Evaluate an Einstein summation over the given input tensors.
 *
 * The implementation closely mirrors NumPy's C einsum machinery: the
 * subscripts string is parsed one operand at a time, repeated ("diagonal")
 * labels within an operand are combined, an iteration space is built whose
 * leading dims are the output dims and whose trailing dims are the reduction
 * dims, and each output element is computed as a sum of products taken over
 * the (compile-time constant) reduction space.
 *
 * \param subscripts_str The einsum subscripts, e.g. "ij,jk->ik". An explicit
 *        "->" output is optional; without it the output labels are the labels
 *        that appear exactly once, in alphabetical order (NumPy implicit mode).
 * \param inputs The input tensors, one per comma-separated operand subscript.
 * \param name The name of the resulting compute operation.
 * \param tag The tag of the resulting compute operation.
 * \return The output tensor holding the summation result.
 */
inline Tensor einsum(const std::string& subscripts_str, const Array<Tensor> inputs,
                     std::string name = "T_einsum", std::string tag = kEinsum) {
  // NOTE(review): `back` is always false here, so `inputs[iop + back]` is
  // effectively `inputs[iop]`. Presumably retained for parity with the code
  // this was ported from — confirm before simplifying.
  bool back = false;
  const char* subscripts = subscripts_str.data();
  const char* head = subscripts;  // start of the string, kept for bounds checks
  const int nop = inputs.size();

  /* Step 1: Parse the subscripts string into label_counts and op_labels */
  int iop, idim, min_label = LABELRANGE - 1, max_label = 0;
  char label_counts[LABELRANGE], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
  memset(label_counts, 0, sizeof(label_counts));
  for (iop = 0; iop < nop; ++iop) {
    // An operand's labels run until the next ',' (next operand) or '-'
    // (the start of the "->" output marker).
    int length = static_cast<int>(strcspn(subscripts, ",-"));
    // The number of comma-separated operand subscripts must match `nop`.
    CHECK(!(iop == nop - 1 && subscripts[length] == ','))
        << "more operands provided to einstein sum function "
        << "than specified in the subscripts string";
    CHECK(!(iop < nop - 1 && subscripts[length] != ','))
        << "fewer operands provided to einstein sum function "
        << "than specified in the subscripts string";
    // Fills op_labels[iop] (one label per dim; 0 = broadcast dim, negative =
    // repeated-label back reference) and updates the global label counters.
    CHECK_EQ(ParseOperandSubscripts(subscripts, length, inputs[iop + back].ndim(), iop,
                                    op_labels[iop], label_counts, &min_label, &max_label),
             0);
    /* Move subscripts to the start of the labels for the next op */
    subscripts += length;
    if (iop < nop - 1) {
      CHECK_LT(subscripts - head, subscripts_str.length()) << "subscripts out of range";
      subscripts++;  // skip over the ',' separator
    }
  }

  /*
   * Find the number of broadcast dimensions, which is the maximum
   * number of labels == 0 in an op_labels array.
   */
  int ndim_broadcast = 0;
  for (iop = 0; iop < nop; ++iop) {
    int count_zeros = 0;
    int ndim;
    char* labels = op_labels[iop];
    ndim = inputs[iop + back].ndim();
    for (idim = 0; idim < ndim; ++idim) {
      if (labels[idim] == 0) {  // 0 marks an unlabeled (broadcast) dimension
        ++count_zeros;
      }
    }
    if (count_zeros > ndim_broadcast) {
      ndim_broadcast = count_zeros;
    }
  }

  /*
   * If there is no output signature, fill output_labels and ndim_output
   * using each label that appeared once, in alphabetical order.
   */
  int label, ndim_output;
  char output_labels[NPY_MAXDIMS];
  if (subscripts[0] == '\0') {
    /* If no output was specified, always broadcast left, as usual. */
    for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
      output_labels[ndim_output] = 0;
    }
    // Implicit mode: labels used exactly once become output dims, in order.
    for (label = min_label; label <= max_label; ++label) {
      if (label_counts[label] == 1) {
        CHECK(ndim_output < NPY_MAXDIMS) << "einstein sum subscript string has too many "
                                         << "distinct labels";
        output_labels[ndim_output++] = label;
      }
    }
  } else {
    // Explicit mode: the remainder of the string must be "->" + output labels.
    CHECK(subscripts[0] == '-' && subscripts[1] == '>') << "einstein sum subscript string does not "
                                                        << "contain proper '->' output specified";
    subscripts += 2;
    /* Parse the output subscript string. */
    ndim_output = ParseOutputSubscripts(subscripts, strlen(subscripts), ndim_broadcast,
                                        label_counts, output_labels);
    CHECK_GE(ndim_output, 0);
  }

  /*
   * Step 2:
   * Process all the input ops, combining dimensions into their
   * diagonal where specified.
   */
  // Per-operand effective shape and (row-major, element-count) strides after
  // any diagonal combining.
  std::vector<Array<PrimExpr>> opshape(nop), opstride_true(nop);
  for (iop = 0; iop < nop; ++iop) {
    char* labels = op_labels[iop];
    int combine, ndim;
    ndim = inputs[iop + back].ndim();
    /*
     * Check whether any dimensions need to be combined
     *
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    combine = 0;
    for (idim = 0; idim < ndim; ++idim) {
      // Negative label = back reference to an earlier dim with the same label,
      // i.e. a diagonal that must be combined.
      if ((signed char)labels[idim] < 0) {
        combine++;
      }
    }
    /* If any dimensions are combined, create a view which combines them */
    if (combine) {
      Array<PrimExpr> tshape(static_cast<size_t>(ndim - combine), -1);
      Array<PrimExpr> tstride(static_cast<size_t>(ndim - combine), -1);
      GetCombinedDimsView(inputs[iop + back], iop, labels, &tshape, &tstride);
      opshape[iop] = tshape;
      opstride_true[iop] = tstride;
    } else {
      /* No combining needed */
      opshape[iop] = inputs[iop + back]->shape;
      opstride_true[iop] = GetStride(opshape[iop]);
    }
  }

  /*
   * Step 3:
   * Set up the labels for the iterator (output + combined labels).
   * Can just share the output_labels memory, because iter_labels
   * is output_labels with some more labels appended.
   */
  // NOTE: iter_labels intentionally aliases output_labels; the reduction
  // labels are appended after the ndim_output entries already written.
  char* iter_labels = output_labels;
  int ndim_iter = ndim_output;
  for (label = min_label; label <= max_label; ++label) {
    // Any used label not already in the output becomes a reduction dim.
    if (label_counts[label] > 0 && memchr(output_labels, label, ndim_output) == nullptr) {
      CHECK(ndim_iter < NPY_MAXDIMS) << "too many subscripts in einsum";
      iter_labels[ndim_iter++] = label;
    }
  }

  /* Step 4: Set up the op_axes for the iterator */
  // itershape starts as all -1 ("unknown") and is filled from operand dims.
  Array<PrimExpr> itershape(static_cast<size_t>(ndim_iter), -1);
  // iterstride has one entry per operand plus one (index nop) for the output.
  std::vector<Array<PrimExpr>> iterstride(nop + 1,
                                          Array<PrimExpr>(static_cast<size_t>(ndim_iter), 0));
  // output_shape
  std::vector<Array<PrimExpr>> operands;
  for (size_t i = 0; i < inputs.size(); i++) {
    operands.push_back(inputs[i]->shape);
  }
  Array<PrimExpr> oshape = NumpyEinsumShape(subscripts_str, operands);
  Array<PrimExpr> ostride_true = GetStride(oshape);
  Array<PrimExpr> reduceshape;
  std::vector<Array<PrimExpr>> remainshape(nop);
  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
  int* op_axes[NPY_MAXARGS];
  for (iop = 0; iop < nop; ++iop) {
    op_axes[iop] = op_axes_arrays[iop];
    // op_axes[iop][idim] maps iterator dim -> operand dim (-1 if the operand
    // does not carry that iterator dim, i.e. it is broadcast for this operand).
    CHECK_GE(PrepareOpAxes(opshape[iop].size(), iop, op_labels[iop], op_axes[iop], ndim_iter,
                           iter_labels),
             0);
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes[iop][idim] != -1) {
        iterstride[iop].Set(idim, opstride_true[iop][op_axes[iop][idim]]);
        if (GetConstInt(itershape[idim]) != -1) {
          // Prefer a non-1 extent over a broadcastable extent of 1.
          if (GetConstInt(itershape[idim]) == 1) {
            itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
          }
        } else {
          itershape.Set(idim, opshape[iop][op_axes[iop][idim]]);
        }
      }
    }
  }
  // The output's strides occupy the leading ndim_output iterator dims.
  for (idim = 0; idim < ndim_output; ++idim) {
    iterstride[nop].Set(idim, ostride_true[idim]);
  }
  // The trailing (ndim_iter - ndim_output) iterator dims form the reduction space.
  reduceshape = Array<PrimExpr>(static_cast<size_t>(ndim_iter - ndim_output), 0);
  for (idim = ndim_output; idim < ndim_iter; ++idim) {
    reduceshape.Set(idim - ndim_output, itershape[idim]);
  }
  for (iop = 0; iop < nop; iop++) {
    Array<Integer> rsh;
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes_arrays[iop][idim] == -1) {
        rsh.push_back(GetConstInt(itershape[idim]));
      } else {
        // NOTE(review): GetConstInt is applied to the *comparison expression*
        // (a PrimExpr `!=` node), relying on it const-folding to 0/1; the same
        // idiom appears in the reduction-index carry logic below. Confirm this
        // is the intended pattern before "fixing" the parenthesization.
        if (GetConstInt(itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]])) {
          rsh.push_back(GetConstInt(itershape[idim]));
        }
      }
    }
    remainshape[iop] = Array<PrimExpr>(rsh.begin(), rsh.end());
  }
  // exclude the 0-dim case
  if (ndim_iter == 0) {
    ndim_iter = 1;
  }
  // Pad every shape/stride array to a uniform rank of ndim_iter so the
  // per-element loops below can index them uniformly.
  itershape = Pad(itershape, ndim_iter);
  for (iop = 0; iop <= nop; ++iop) {
    iterstride[iop] = Pad(iterstride[iop], ndim_iter);
  }
  // oshape = Pad(oshape, ndim_iter);
  reduceshape = Pad(reduceshape, ndim_iter);
  for (iop = 0; iop < nop; ++iop) {
    opshape[iop] = Pad(opshape[iop], ndim_iter);
    remainshape[iop] = Pad(remainshape[iop], ndim_iter);
  }
  // ostride and rstride
  // Split each operand's iterator strides into the part addressed by the
  // output indices (ostride) and the part addressed by the reduction
  // indices (rstride); padded slots get stride 1.
  Array<Array<PrimExpr>> ostride;
  Array<Array<PrimExpr>> rstride;
  for (iop = 0; iop < nop; ++iop) {
    Array<PrimExpr> otmp(static_cast<size_t>(ndim_iter), 0);
    Array<PrimExpr> rtmp(static_cast<size_t>(ndim_iter), 0);
    for (idim = 0; idim < ndim_iter; ++idim) {
      otmp.Set(idim, idim < ndim_output ? iterstride[iop][idim] : 1);
      rtmp.Set(idim, idim < ndim_iter - ndim_output ? iterstride[iop][idim + ndim_output] : 1);
    }
    ostride.push_back(otmp);
    rstride.push_back(rtmp);
  }
  // func: input indices => return corresponding value
  // The reduction loop below is unrolled at *compile* time: ridx is advanced
  // like a multi-digit counter over reduceshape (whose extents must be
  // constant for GetConstInt to fold), emitting one product term per step.
  auto func = [inputs, oshape, ostride, reduceshape, ndim_iter, rstride,
               nop](const Array<Var>& input_indices) -> PrimExpr {
    for (int rdim = 0; rdim < ndim_iter; ++rdim) {
      if (GetConstInt(reduceshape[rdim]) == 0) {
        return 0;  // empty reduction space: the sum over no terms is 0
      }
    }
    Array<PrimExpr> ridx = UnravelIndex(0, reduceshape);
    PrimExpr sum = 0;
    bool rec_flag = false;
    do {
      PrimExpr tmp = 1;
      for (int iop = 0; iop < nop; ++iop) {
        // NOTE(review): `iop != -1` is always true for iop in [0, nop);
        // looks vestigial — confirm before removing.
        if (iop != -1) {
          // Linearize (output indices, reduction indices) against this
          // operand's strides, then unravel into its concrete shape.
          PrimExpr k = 0;
          for (size_t i = 0; i < input_indices.size(); ++i) {
            k += input_indices[i] * ostride[iop][i];
          }
          for (size_t i = 0; i < ridx.size(); ++i) {
            k += ridx[i] * rstride[iop][i];
          }
          Array<PrimExpr> temp_indices = UnravelIndex(k, inputs[iop]->shape);
          tmp = tmp * inputs[iop](temp_indices);
        }
      }
      sum += tmp;
      // Advance ridx as an odometer: increment the last digit, then
      // propagate carries toward digit 0.
      ridx.Set(ridx.size() - 1, ridx[ridx.size() - 1] + 1);
      for (int i = static_cast<int>(ridx.size() - 1);
           (i > 0) && GetConstInt(ridx[i] >= reduceshape[i]); --i) {
        ridx.Set(i, ridx[i] - reduceshape[i]);
        ridx.Set(i - 1, ridx[i - 1] + 1);
      }
      // Stop once the most significant digit overflows its extent.
      rec_flag = GetConstInt(ridx[0] < reduceshape[0]);
    } while (rec_flag);
    return sum;
  };
  return compute(oshape, func, name, tag);
}