in Sources/x10/xla_tensor/xla_lower_util.cpp [396:440]
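CreateMatMul dispatches on operand ranks: rank-1 and rank-2 combinations map to a plain dot product, while higher ranks are rank-aligned, broadcast, and lowered to a batched xla::DotGeneral.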
xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) {
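  // Emit the dot ops below with the configured matmul precision.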
  xla::PrecisionConfig precision_config =
      XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision());
  xla::Shape lhs_shape = XlaHelpers::ShapeOfXlaOp(lhs);
  xla::Shape rhs_shape = XlaHelpers::ShapeOfXlaOp(rhs);
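  // vector x vector, matrix x matrix, and matrix x vector all lower to a
  // single dot product.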
  if ((lhs_shape.rank() == 1 && rhs_shape.rank() == 1) ||
      (lhs_shape.rank() == 2 && rhs_shape.rank() == 2) ||
      (lhs_shape.rank() == 2 && rhs_shape.rank() == 1)) {
    return BuildDot(lhs, rhs);
  }
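  // vector x matrix: promote the vector to a 1 x n matrix, multiply, then
  // drop the leading unit dimension from the result.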
  if (lhs_shape.rank() == 1 && rhs_shape.rank() == 2) {
    xla::XlaOp reshaped_lhs =
        XlaHelpers::DynamicReshape(lhs, {1, lhs_shape.dimensions(0)});
    return XlaHelpers::DynamicReshape(BuildDot(reshaped_lhs, rhs),
                                      {rhs_shape.dimensions(1)});
  }
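  // Batched case: at least one operand has rank >= 3.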
  if (lhs_shape.rank() >= 1 && rhs_shape.rank() >= 1 &&
      (lhs_shape.rank() >= 3 || rhs_shape.rank() >= 3)) {
    xla::XlaOp reshaped_lhs = lhs;
    xla::XlaOp reshaped_rhs = rhs;
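    // Align ranks by padding the lower-rank operand with leading unit
    // dimensions.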
    if (lhs_shape.rank() > rhs_shape.rank()) {
      reshaped_rhs = DotExpand(reshaped_rhs, rhs_shape, lhs_shape);
      rhs_shape = XlaHelpers::ShapeOfXlaOp(reshaped_rhs);
    } else if (rhs_shape.rank() > lhs_shape.rank()) {
      reshaped_lhs = DotExpand(reshaped_lhs, lhs_shape, rhs_shape);
      lhs_shape = XlaHelpers::ShapeOfXlaOp(reshaped_lhs);
    }
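    // Broadcast the batch dimensions of both operands to a common shape.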
    std::tie(reshaped_lhs, reshaped_rhs) =
        DotBroadcast(reshaped_lhs, lhs_shape, reshaped_rhs, rhs_shape);
    // At this point the lhs and rhs ranks are the same, so the code below
    // uses the lhs rank for both operands.
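    // Every leading dimension is a batch dimension; only the last two take
    // part in each per-batch matrix product.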
    xla::DotDimensionNumbers dims;
    for (xla::int64 i = 0; i < lhs_shape.rank() - 2; ++i) {
      dims.add_lhs_batch_dimensions(i);
      dims.add_rhs_batch_dimensions(i);
    }
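    // Contract the last dimension of the lhs against the second-to-last
    // dimension of the rhs.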
    dims.add_lhs_contracting_dimensions(lhs_shape.rank() - 1);
    dims.add_rhs_contracting_dimensions(lhs_shape.rank() - 2);
    return xla::DotGeneral(reshaped_lhs, reshaped_rhs, dims, &precision_config);
  }
  XLA_ERROR() << "Unsupported matmul operation: matmul(" << lhs_shape << ", "
              << rhs_shape << ")";
}
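
// A minimal usage sketch (illustrative, not part of this file): lowering
// f32[4, 3, 5] x f32[5, 7]. DotExpand first pads the rhs to f32[1, 5, 7],
// DotBroadcast broadcasts it to f32[4, 5, 7], and the DotGeneral above
// batches over dimension 0 while contracting lhs dim 2 against rhs dim 1,
// yielding f32[4, 3, 7]:
//
//   xla::XlaBuilder builder("matmul_example");
//   xla::XlaOp lhs = xla::Parameter(
//       &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {4, 3, 5}), "lhs");
//   xla::XlaOp rhs = xla::Parameter(
//       &builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {5, 7}), "rhs");
//   xla::XlaOp product = CreateMatMul(lhs, rhs);  // shape f32[4, 3, 7]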