in src/operator/tensor/matrix_op-inl.h [597:679]
void BatchDotBackward_(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
  // In-place writes over the incoming buffers are not supported for this op.
  CHECK_NE(req[1], kWriteInplace);
  CHECK_NE(req[0], kWriteInplace);
  CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64)
      << "dot only supports float32 and float64";
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // inputs:  [0] = dz (head gradient), [1] = x (lhs), [2] = y (rhs)
    // outputs: [0] = dx, [1] = dy
    mshadow::Tensor<xpu, 3, DType> grad_out = inputs[0].get<xpu, 3, DType>(s);
    mshadow::Tensor<xpu, 3, DType> lhs = inputs[1].get<xpu, 3, DType>(s);
    mshadow::Tensor<xpu, 3, DType> rhs = inputs[2].get<xpu, 3, DType>(s);
    mshadow::Tensor<xpu, 3, DType> grad_lhs = outputs[0].get<xpu, 3, DType>(s);
    mshadow::Tensor<xpu, 3, DType> grad_rhs = outputs[1].get<xpu, 3, DType>(s);
    // Scratch holding the per-batch pointer arrays the batched GEMM needs:
    // row 0 is used by the dy GEMM, row 1 by the dx GEMM.
    mshadow::Tensor<xpu, 2, DType*> workspace =
        ctx.requested[0].get_space_typed<xpu, 2, DType*>(
            mshadow::Shape2(2, 3 * grad_out.size(0)), s);
    mshadow::Tensor<xpu, 1, DType*> rhs_workspace = workspace[0];
    mshadow::Tensor<xpu, 1, DType*> lhs_workspace = workspace[1];
    // beta = 1 accumulates into the existing gradient (kAddTo); beta = 0 overwrites.
    const DType beta_lhs = (kAddTo == req[0]) ? DType(1.0f) : DType(0.0f);
    const DType beta_rhs = (kAddTo == req[1]) ? DType(1.0f) : DType(0.0f);
    const bool write_lhs = (kNullOp != req[0]);
    const bool write_rhs = (kNullOp != req[1]);
    if (param.transpose_a) {
      if (param.transpose_b) {
        // Gradient of z = dot(x.T, y.T):
        //   dy = dot(x, dz).T = dot(dz.T, x.T)
        //   dx = dot(dz, y).T = dot(y.T, dz.T)
        if (write_rhs) {
          mshadow::BatchGEMM<true, true>(grad_rhs, grad_out, lhs, DType(1.0f),
                                         beta_rhs, rhs_workspace);
        }
        if (write_lhs) {
          mshadow::BatchGEMM<true, true>(grad_lhs, rhs, grad_out, DType(1.0f),
                                         beta_lhs, lhs_workspace);
        }
      } else {
        // Gradient of z = dot(x.T, y):
        //   dy = dot(x, dz)
        //   dx = dot(dz, y.T).T = dot(y, dz.T)
        if (write_rhs) {
          mshadow::BatchGEMM<false, false>(grad_rhs, lhs, grad_out, DType(1.0f),
                                           beta_rhs, rhs_workspace);
        }
        if (write_lhs) {
          mshadow::BatchGEMM<false, true>(grad_lhs, rhs, grad_out, DType(1.0f),
                                          beta_lhs, lhs_workspace);
        }
      }
    } else {
      if (param.transpose_b) {
        // Gradient of z = dot(x, y.T):
        //   dy = dot(x.T, dz).T = dot(dz.T, x)
        //   dx = dot(dz, y)
        if (write_rhs) {
          mshadow::BatchGEMM<true, false>(grad_rhs, grad_out, lhs, DType(1.0f),
                                          beta_rhs, rhs_workspace);
        }
        if (write_lhs) {
          mshadow::BatchGEMM<false, false>(grad_lhs, grad_out, rhs, DType(1.0f),
                                           beta_lhs, lhs_workspace);
        }
      } else {
        // Gradient of z = dot(x, y):
        //   dy = dot(x.T, dz)
        //   dx = dot(dz, y.T)
        if (write_rhs) {
          mshadow::BatchGEMM<true, false>(grad_rhs, lhs, grad_out, DType(1.0f),
                                          beta_rhs, rhs_workspace);
        }
        if (write_lhs) {
          mshadow::BatchGEMM<false, true>(grad_lhs, grad_out, rhs, DType(1.0f),
                                          beta_lhs, lhs_workspace);
        }
      }
    }
  });
}