/*!
* Copyright (c) 2017 by Contributors
* \file la_op.h
* \brief Operators for advanced linear algebra.
*/
#ifndef MXNET_OPERATOR_TENSOR_LA_OP_H_
#define MXNET_OPERATOR_TENSOR_LA_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <algorithm>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
// Parameters for general matrix-matrix multiply-accumulate (mac)
struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
bool transpose_a, transpose_b;
double alpha, beta;
DMLC_DECLARE_PARAMETER(LaMatrixMacParam) {
DMLC_DECLARE_FIELD(transpose_a)
.set_default(false)
.describe("Multiply with transposed of first input (A).");
DMLC_DECLARE_FIELD(transpose_b)
.set_default(false)
.describe("Multiply with transposed of second input (B).");
DMLC_DECLARE_FIELD(alpha)
.set_default(1.0)
.describe("Scalar factor multiplied with A*B.");
DMLC_DECLARE_FIELD(beta)
.set_default(1.0)
.describe("Scalar factor multiplied with C.");
}
};
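// Illustrative semantics (a sketch derived from the parameter descriptions,
// not normative documentation): an operator using these parameters computes
//   out = alpha * op(A) * op(B) + beta * C
// where op(X) = X^T when the corresponding transpose flag is set and
// op(X) = X otherwise.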
// Parameters for general matrix-matrix multiply
struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
bool transpose_a, transpose_b;
double alpha;
DMLC_DECLARE_PARAMETER(LaMatrixMultParam) {
DMLC_DECLARE_FIELD(transpose_a)
.set_default(false)
.describe("Multiply with transposed of first input (A).");
DMLC_DECLARE_FIELD(transpose_b)
.set_default(false)
.describe("Multiply with transposed of second input (B).");
DMLC_DECLARE_FIELD(alpha)
.set_default(1.0)
.describe("Scalar factor multiplied with A*B.");
}
};
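// Illustrative semantics: out = alpha * op(A) * op(B), i.e. the mac variant
// above without the accumulation term beta * C.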
// Parameters for matrix-matrix multiplication where one is a triangular matrix.
struct LaTriangMatrixMultParam : public dmlc::Parameter<LaTriangMatrixMultParam> {
bool transpose;
bool rightside;
double alpha;
DMLC_DECLARE_PARAMETER(LaTriangMatrixMultParam) {
DMLC_DECLARE_FIELD(transpose)
.set_default(false)
.describe("Use transposed of the triangular matrix");
DMLC_DECLARE_FIELD(rightside)
.set_default(false)
.describe("Multiply triangular matrix from the right to non-triangular one.");
DMLC_DECLARE_FIELD(alpha)
.set_default(1.0)
.describe("Scalar factor to be applied to the result.");
}
};
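// Illustrative semantics: with A the triangular (first) input and B the
// second input, the operator computes
//   out = alpha * op(A) * B   (rightside = false)
//   out = alpha * B * op(A)   (rightside = true)
// where op(A) = A^T when transpose is set.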
// Common function for shape inference for matrix mult and matrix mac.
inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
                                   std::vector<TShape>* in_attrs,
                                   std::vector<TShape>* out_attrs) {
CHECK_GE(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
bool transpose_a(false), transpose_b(false);
if ( in_attrs->size() == 2 ) {
// Matrix-Matrix mult
transpose_a = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_a;
transpose_b = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_b;
} else {
// Matrix-Matrix mac
transpose_a = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_a;
transpose_b = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_b;
}
if ( (*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim() ) {
// Forward shape inference.
const int ndim((*in_attrs)[0].ndim());
std::vector<int> oshape(ndim);
for ( int i = 0; i < ndim-2; ++i ) {
// Both inputs must have same shape except for last two dimensions.
if ( (*in_attrs)[0][i] != (*in_attrs)[1][i] ) return false;
oshape[i] = (*in_attrs)[0][i];
}
CHECK_EQ((transpose_a ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]),
(transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][ndim-2]))
<< "Incompatible matrix dimensions for multiplication";
oshape[ndim-2] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][ndim-2]);
oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][ndim-2] : (*in_attrs)[1][ndim-1]);
TShape tshape(oshape.begin(), oshape.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
if ( in_attrs->size() > 2 ) {
// Infer/check shape of third operand of a mac.
SHAPE_ASSIGN_CHECK(*in_attrs, 2, tshape);
}
return true;
}
// Can't do backward inference of shapes for this operator.
return false;
}
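// Worked example (illustrative shapes): with transpose_a = transpose_b = false,
//   A: (batch, 3, 4), B: (batch, 4, 5)  ->  out: (batch, 3, 5)
// and for a mac the third operand C is inferred/checked as (batch, 3, 5).
// With transpose_a = true, A: (batch, 4, 3) yields the same output shape.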
inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
                                      std::vector<TShape>* in_attrs,
                                      std::vector<TShape>* out_attrs) {
const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
if ( (*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim() ) {
// Forward shape inference.
const int ndim((*in_attrs)[0].ndim());
CHECK_EQ((*in_attrs)[0][ndim-2], (*in_attrs)[0][ndim-1])
<< "First operand must be a tensor of square matrices";
std::vector<int> oshape(ndim);
for ( int i = 0; i < ndim-2; ++i ) {
// Must have same shape except for last two dimensions.
if ( (*in_attrs)[0][i] != (*in_attrs)[1][i] ) return false;
oshape[i] = (*in_attrs)[0][i];
}
if ( param.rightside ) {
// We compute B * A where A is the first and B the second input.
CHECK_EQ((*in_attrs)[0][ndim-2], (*in_attrs)[1][ndim-1])
<< "Incompatible matrix dimensions for multiplication";
oshape[ndim-2] = (*in_attrs)[1][ndim-2];
oshape[ndim-1] = (param.transpose ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]);
} else {
// We compute A * B where A is the first and B the second input.
CHECK_EQ((*in_attrs)[1][ndim-2], (*in_attrs)[0][ndim-1])
<< "Incompatible matrix dimensions for multiplication";
oshape[ndim-2] = (param.transpose ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][ndim-2]);
oshape[ndim-1] = (*in_attrs)[1][ndim-1];
}
TShape tshape(oshape.begin(), oshape.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
return true;
}
if ( (*out_attrs)[0].ndim() >= 2 ) {
// Backward shape inference.
const int odim((*out_attrs)[0].ndim());
std::vector<int> ishape1(odim), ishape2(odim);
for ( int i = 0; i < odim-2; ++i ) {
ishape1[i] = ishape2[i] = (*out_attrs)[0][i];
}
if ( param.rightside ) {
// We compute B * A where A is the first and B the second input.
ishape2[odim-2] = (*out_attrs)[0][odim-2];
ishape1[odim-2] = ishape1[odim-1] = ishape2[odim-1] = (*out_attrs)[0][odim-1];
} else {
// We compute A * B where A is the first and B the second input.
ishape2[odim-1] = (*out_attrs)[0][odim-1];
ishape1[odim-2] = ishape1[odim-1] = ishape2[odim-2] = (*out_attrs)[0][odim-2];
}
TShape tshape1(ishape1.begin(), ishape1.end());
SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape1);
TShape tshape2(ishape2.begin(), ishape2.end());
SHAPE_ASSIGN_CHECK(*in_attrs, 1, tshape2);
return true;
}
return false;
}
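// Worked example (illustrative shapes): rightside = false, transpose = false,
//   A: (batch, n, n), B: (batch, n, m)  ->  out: (batch, n, m).
// Backward inference only has the output shape (batch, n, m) available, so it
// reconstructs A as (batch, n, n) and B as (batch, n, m) accordingly.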
template<int dim>
bool LaReduceShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
// Shape for reduction of the dim lowest dimensions to a scalar.
  // Can only deduce shapes in the forward direction.
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
const int ndim((*in_attrs)[0].ndim());
if ( ndim < dim ) {
return false;
}
std::vector<int> oshape(std::max(1, ndim-dim), 1);
for ( int i = 0; i < ndim - dim; ++i ) {
oshape[i] = (*in_attrs)[0][i];
}
// Will reduce all matrices/vectors to a scalar.
TShape tshape(oshape.begin(), oshape.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
return true;
}
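// Worked example (illustrative): with dim = 2, an input of shape
// (b1, b2, n, m) reduces to an output of shape (b1, b2); a plain (n, m)
// input reduces to the one-element shape (1,).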
// Adapters for calling the various operators with appropriate signatures.
template<typename xpu, typename DType, int idim, int odim, int inum, int onum, typename laop>
struct LaOpCaller {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
CHECK(false) << "no specialized LaOpCaller defined for template parameters";
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index],
attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[3].FlatToKD<xpu, idim+1, DType>(s)[index],
outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index],
attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const int index,
const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu> *s) {
laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
inputs[3].FlatToKD<xpu, idim+1, DType>(s)[index],
outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index],
outputs[2].FlatToKD<xpu, odim+1, DType>(s)[index],
attrs);
}
};
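// A laop type only needs a static "op" method whose arity matches one of the
// specializations above; each tensor argument is one slice of the batch.
// A minimal sketch (hypothetical operator, not part of this file), assuming
// idim = odim = 2 and inum = onum = 1, with the output taken as a cheap
// by-value view so that the temporaries produced by FlatToKD can bind:
//
//   struct identity_laop {
//     template<typename xpu, typename DType>
//     static void op(const mshadow::Tensor<xpu, 2, DType>& in,
//                    mshadow::Tensor<xpu, 2, DType> out,
//                    const nnvm::NodeAttrs& attrs) {
//       mshadow::Copy(out, in, out.stream_);  // copy one matrix of the batch
//     }
//   };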
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), inum);
CHECK_EQ(outputs.size(), onum);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
int N(-1);
for ( int i = 0; i < inum; ++i ) {
CHECK_EQ(inputs[i].CheckContiguous(), true);
const int M(inputs[i].FlatToKD<xpu, idim+1, OType>(s).size(0));
CHECK_EQ((N == -1 || N == M), true);
N = M;
}
for ( int i = 0; i < onum; ++i ) {
CHECK_EQ(outputs[i].CheckContiguous(), true);
CHECK_EQ((req[i] == kWriteTo || req[i] == kWriteInplace), true);
const int M(outputs[i].FlatToKD<xpu, odim+1, OType>(s).size(0));
CHECK_EQ((N == -1 || N == M), true);
N = M;
}
for ( int i = 0; i < N; ++i ) {
LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, i, attrs, s);
}
});
}
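// Registration sketch (hypothetical operator name; actual registrations live
// in la_op.cc, not in this header). The template arguments are
// <device, idim, odim, #inputs, #outputs, laop>:
//
//   NNVM_REGISTER_OP(_hypothetical_identity)
//   .set_num_inputs(1)
//   .set_num_outputs(1)
//   .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
//   .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
//   .set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 1, 1, identity_laop>);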
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), inum);
CHECK_EQ(outputs.size(), onum);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
int N(-1);
for ( int i = 0; i < inum; ++i ) {
CHECK_EQ(inputs[i].CheckContiguous(), true);
const int M(inputs[i].FlatToKD<xpu, idim+1, OType>(s).size(0));
CHECK_EQ((N == -1 || N == M), true);
N = M;
}
std::vector<TBlob> tspace(outputs);
for ( int i = 0; i < onum; ++i ) {
CHECK_EQ(outputs[i].CheckContiguous(), true);
const int M(outputs[i].FlatToKD<xpu, odim+1, OType>(s).size(0));
CHECK_EQ((N == -1 || N == M), true);
N = M;
if ( req[i] == kAddTo ) {
tspace[i].dptr_ = ctx.requested[ResourceRequest::kTempSpace]
.get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_;
}
}
for ( int i = 0; i < N; ++i ) {
LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace, i, attrs, s);
}
for ( int i = 0; i < onum; ++i ) {
if ( req[i] == kAddTo ) {
Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s);
out += tspace[i].FlatTo1D<xpu, OType>(s);
}
}
});
}
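// Backward registration sketch (hypothetical): the gradient node typically
// takes the output gradient plus any saved inputs/outputs and produces one
// gradient per forward input, e.g. for the identity sketch above:
//
//   NNVM_REGISTER_OP(_backward_hypothetical_identity)
//   .set_num_inputs(1)
//   .set_num_outputs(1)
//   .set_attr<nnvm::TIsBackward>("TIsBackward", true)
//   .set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 1, 1, identity_laop>);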
template<typename xpu, int idim, typename laop>
void LaReduceForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);
CHECK_EQ(inputs[0].CheckContiguous(), true);
CHECK_EQ(outputs[0].CheckContiguous(), true);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
Tensor<xpu, idim+1, OType> in(inputs[0].FlatToKD<xpu, idim+1, OType>(s));
Tensor<xpu, 1, OType> out(outputs[0].FlatTo1D<xpu, OType>(s));
const int N(outputs[0].Size());
CHECK_EQ(in.size(0), N);
for ( int i = 0; i < N; ++i ) {
laop::op(in[i], out[i], attrs);
}
});
}
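// The reduce laop contract differs from the batched callers above: it maps one
// idim-dimensional slice to one scalar. A cpu-only sketch (hypothetical),
// assuming idim = 2:
//
//   struct sum_reduce {
//     template<typename DType>
//     static void op(const mshadow::Tensor<cpu, 2, DType>& in, DType& out,
//                    const nnvm::NodeAttrs& attrs) {
//       out = DType(0);
//       for (index_t i = 0; i < in.size(0); ++i)
//         for (index_t j = 0; j < in.size(1); ++j) out += in[i][j];
//     }
//   };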
template<typename xpu, int idim, typename laop>
void LaReduceBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(inputs.size(), 2);
CHECK_EQ(outputs.size(), 1);
CHECK_EQ(inputs[0].CheckContiguous(), true);
CHECK_EQ(inputs[1].CheckContiguous(), true);
CHECK_EQ(outputs[0].CheckContiguous(), true);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
const int N(inputs[0].Size());
Tensor<xpu, 1, OType> in0(inputs[0].FlatTo1D<xpu, OType>(s));
Tensor<xpu, idim+1, OType> in1(inputs[1].FlatToKD<xpu, idim+1, OType>(s));
Tensor<xpu, idim+1, OType> out(outputs[0].FlatToKD<xpu, idim+1, OType>(s));
for ( int i = 0; i < N; ++i ) {
      // req has one entry per output; use req[0], since i indexes scalars here.
      laop::op(in0[i], in1[i], out[i], attrs, (req[0] == kAddTo));
}
});
}
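// Matching backward sketch (hypothetical): receives the scalar output gradient
// and the original input slice, and writes or accumulates the input gradient:
//
//   struct sum_reduce_backward {
//     template<typename DType>
//     static void op(const DType& ograd,
//                    const mshadow::Tensor<cpu, 2, DType>& in,
//                    mshadow::Tensor<cpu, 2, DType> igrad,
//                    const nnvm::NodeAttrs& attrs, bool add) {
//       // d(sum)/d(in[i][j]) = 1, so the input gradient is ograd everywhere.
//       for (index_t i = 0; i < igrad.size(0); ++i)
//         for (index_t j = 0; j < igrad.size(1); ++j)
//           igrad[i][j] = add ? igrad[i][j] + ograd : ograd;
//     }
//   };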
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_LA_OP_H_