// BatchDotBackward_ — excerpt from src/operator/tensor/matrix_op-inl.h (lines 597-679).


/*!
 * \brief Backward pass for batch_dot: per-batch z[i] = dot(x[i], y[i]), where
 *        either operand may be transposed according to DotParam
 *        (param.transpose_a / param.transpose_b).
 *
 * \param attrs   node attributes; attrs.parsed holds the DotParam with the
 *                transpose flags used by the forward pass.
 * \param ctx     operator context providing the device stream and the
 *                requested temporary workspace (ctx.requested[0]).
 * \param inputs  [0] = out_grad (dz), [1] = lhs_data (x), [2] = rhs_data (y);
 *                all are 3-D tensors (batch as the leading dimension).
 * \param req     write disposition per output: kNullOp skips the gradient,
 *                kAddTo accumulates (beta = 1), otherwise overwrite (beta = 0).
 *                kWriteInplace is rejected below.
 * \param outputs [0] = lhs_grad (dx), [1] = rhs_grad (dy).
 */
void BatchDotBackward_(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
  // In-place writes are not supported: each gradient GEMM reads the original
  // lhs/rhs data, which an in-place output could alias.
  CHECK_NE(req[1], kWriteInplace);
  CHECK_NE(req[0], kWriteInplace);
  CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64)
      << "dot only supports float32 and float64";
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    mshadow::Tensor<xpu, 3, DType> mout_grad = inputs[0].get<xpu, 3, DType>(s);   // dz
    mshadow::Tensor<xpu, 3, DType> mlhs_data = inputs[1].get<xpu, 3, DType>(s);   // x
    mshadow::Tensor<xpu, 3, DType> mrhs_data = inputs[2].get<xpu, 3, DType>(s);   // y
    mshadow::Tensor<xpu, 3, DType> mlhs_grad = outputs[0].get<xpu, 3, DType>(s);  // dx
    mshadow::Tensor<xpu, 3, DType> mrhs_grad = outputs[1].get<xpu, 3, DType>(s);  // dy
    // Scratch space for BatchGEMM: two rows of 3 * batch_size matrix pointers
    // (presumably one A/B/C pointer per batch element for the batched GEMM
    // backend), one row per gradient so the two GEMMs below don't share it.
    mshadow::Tensor<xpu, 2, DType*> workspace =
      ctx.requested[0].get_space_typed<xpu, 2, DType*>(
        mshadow::Shape2(2, 3 * mout_grad.size(0)), s);
    mshadow::Tensor<xpu, 1, DType*> rhs_workspace = workspace[0];
    mshadow::Tensor<xpu, 1, DType*> lhs_workspace = workspace[1];
    // Four cases below: one per (transpose_a, transpose_b) combination.
    // In each, BatchGEMM<TA, TB>(C, A, B, alpha, beta, ws) computes
    // C = alpha * op(A) * op(B) + beta * C, with beta = 1 when accumulating
    // (kAddTo) and 0 when overwriting; kNullOp skips the GEMM entirely.
    if (param.transpose_a && param.transpose_b) {
      // Gradient of z = dot(x.T, y.T)
      // dy = dot(x, dz).T = dot(dz.T, x.T)
      // dx = dot(dz, y).T = dot(y.T, dz.T)
      if (kNullOp != req[1]) {
        // dy = dz.T * x.T
        mshadow::BatchGEMM<true, true>(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f,
                                        (kAddTo == req[1]) ? (DType)1.0f :  (DType)0.0f,
                                        rhs_workspace);
      }
      if (kNullOp != req[0]) {
        // dx = y.T * dz.T
        mshadow::BatchGEMM<true, true>(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f,
                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
                                        lhs_workspace);
      }
    } else if (!param.transpose_a && param.transpose_b) {
      // Gradient of z = dot(x, y.T)
      // dy = dot(x.T, dz).T = dot(dz.T, x)
      // dx = dot(dz, y)
      if (kNullOp != req[1]) {
        // dy = dz.T * x
        mshadow::BatchGEMM<true, false>(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f,
                                        (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
                                        rhs_workspace);
      }
      if (kNullOp != req[0]) {
        // dx = dz * y
        mshadow::BatchGEMM<false, false>(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f,
                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
                                        lhs_workspace);
      }
    } else if (param.transpose_a && !param.transpose_b) {
      // Gradient of z = dot(x.T, y)
      // dy = dot(x, dz)
      // dx = dot(dz, y.T).T = dot(y, dz.T)
      if (kNullOp != req[1]) {
        // dy = x * dz
        mshadow::BatchGEMM<false, false>(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f,
                                        (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
                                        rhs_workspace);
      }
      if (kNullOp != req[0]) {
        // dx = y * dz.T
        mshadow::BatchGEMM<false, true>(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f,
                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
                                        lhs_workspace);
      }
    } else {
      // Gradient of z = dot(x, y)
      // dy = dot(x.T, dz)
      // dx = dot(dz, y.T)
      if (kNullOp != req[1]) {
        // dy = x.T * dz
        mshadow::BatchGEMM<true, false>(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f,
                                        (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f,
                                        rhs_workspace);
      }
      if (kNullOp != req[0]) {
        // dx = dz * y.T
        mshadow::BatchGEMM<false, true>(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f,
                                        (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f,
                                        lhs_workspace);
      }
    }
  });
}