/*!
* Copyright (c) 2017 by Contributors
* \file sample_multinomial_op.h
* \brief Operator for sampling from multinomial distributions
*/
#ifndef MXNET_OPERATOR_RANDOM_SAMPLE_MULTINOMIAL_OP_H_
#define MXNET_OPERATOR_RANDOM_SAMPLE_MULTINOMIAL_OP_H_
#include <mxnet/operator_util.h>
#include <vector>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
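
/*! \brief Parameters for the multinomial sampling operator. */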
struct SampleMultinomialParam : public dmlc::Parameter<SampleMultinomialParam> {
TShape shape;
bool get_prob;
int dtype;
DMLC_DECLARE_PARAMETER(SampleMultinomialParam) {
DMLC_DECLARE_FIELD(shape)
.set_default(TShape())
.describe("Shape to be sampled from each random distribution.");
DMLC_DECLARE_FIELD(get_prob)
.set_default(false)
.describe("Whether to also return the log probability of sampled "
"result. This is usually used for differentiating through "
"stochastic variables, e.g. in reinforcement learning.");
DMLC_DECLARE_FIELD(dtype)
.add_enum("int32", mshadow::kInt32)
.set_default(mshadow::kInt32)
.describe("DType of the output in case this can't be inferred. "
"Only support int32 for now.");
}
};
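
/*!
 * \brief Shape inference. The trailing axis of the input holds the
 * per-category probabilities; every leading axis indexes an independent
 * distribution. The output replaces that trailing axis with param.shape,
 * e.g. an input of shape (2, 3, 5) with shape=(4,) yields samples of
 * shape (2, 3, 4). A 1-D input yields samples of shape param.shape, or
 * (1,) when no shape is given. With get_prob, the second output (the log
 * probabilities) always matches the sample shape.
 */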
inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
const TShape& ishape = (*in_attrs)[0];
if (!ishape.ndim()) return false;
if (ishape.ndim() == 1) {
if (param.shape.ndim()) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
      if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, param.shape);
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(1));
      if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1));
}
return true;
}
TShape oshape(ishape.ndim() - 1 + param.shape.ndim());
for (size_t i = 0; i < ishape.ndim() - 1; ++i) {
oshape[i] = ishape[i];
}
for (size_t i = 0; i < param.shape.ndim(); ++i) {
oshape[i + ishape.ndim() - 1] = param.shape[i];
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
return true;
}
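
/*!
 * \brief Type inference. Samples take param.dtype (only int32 for now);
 * the optional log-probability output inherits the input's floating
 * point type.
 */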
inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
int itype = (*in_attrs)[0];
if (itype == -1) return false;
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype);
if (param.get_prob) {
TYPE_ASSIGN_CHECK(*out_attrs, 1, itype);
}
return true;
}
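
/*!
 * \brief Kernel that draws categorical samples by inverse-CDF search.
 *
 * Instance i owns one distribution of K categories and draws M samples.
 * For each sample it walks the (assumed normalized) distribution,
 * accumulating probability mass until the running sum exceeds the
 * pre-generated uniform variate. Worked example: for dist = {0.2, 0.5, 0.3}
 * and a uniform draw of 0.6, the running sums are 0.2 then 0.7, so the
 * search stops at k = 1 and class 1 is emitted. If rounding error keeps
 * the total accumulated mass below the draw, the last class K-1 is used
 * as a fallback. When prob is non-null, the log probability of the chosen
 * class is also recorded (via logf, i.e. in single precision).
 */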
struct SampleMultinomialKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, index_t K, index_t M,
DType* dist, float* uniform, IType* out,
DType* prob) {
for (index_t j = 0; j < M; ++j) {
DType loc = static_cast<DType>(uniform[i*M + j]);
DType acc = 0;
bool found = false;
for (index_t k = 0; k < K; ++k) {
acc += dist[i*K + k];
if (acc > loc) {
found = true;
out[i*M + j] = static_cast<IType>(k);
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + k]);
break;
}
}
if (!found) {
out[i*M + j] = static_cast<IType>(K-1);
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + K - 1]);
}
}
}
};
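
/*!
 * \brief Forward pass. Flattens the input into N distributions of K
 * categories, draws N*M uniform variates from the requested random
 * resource (ctx.requested[0]) into temporary workspace (ctx.requested[1]),
 * then launches one kernel instance per distribution.
 */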
template<typename xpu>
void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
const SampleMultinomialParam& param = nnvm::get<SampleMultinomialParam>(attrs.parsed);
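  // K: number of categories (trailing input axis); N: number of
  // independent distributions; M: samples drawn per distribution.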
index_t K = inputs[0].shape_[inputs[0].ndim()-1];
index_t N = inputs[0].Size()/K;
index_t M = outputs[0].Size()/N;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> uniform =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
prnd->SampleUniform(&uniform, 0, 1);
Kernel<SampleMultinomialKernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int>(),
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
});
}
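
/*!
 * \brief Backward pass, used when get_prob = true. The forward pass emits
 * prob = log(dist[k]) at each sampled index k, so the gradient w.r.t. the
 * distribution is ograd / dist[k] scattered into the sampled positions
 * (the score-function, or REINFORCE, gradient). The concrete scatter is
 * supplied by the `kernel` template parameter so call sites can plug in
 * CPU- or GPU-specific accumulation; the launch below passes
 * {ograd, dist, samples} as inputs and writes the distribution gradient
 * to outputs[0], zero-initializing it first unless req[0] is kAddTo.
 */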
template<typename kernel, typename xpu>
void SampleMultinomialBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
if (req[0] == kNullOp) return;
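  // Recover the forward pass layout: outputs[0] (the distribution
  // gradient) has K categories over N distributions; inputs[0] (ograd)
  // holds M samples per distribution.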
index_t K = outputs[0].shape_[outputs[0].ndim()-1];
index_t N = outputs[0].Size()/K;
index_t M = inputs[0].Size()/N;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
if (req[0] != kAddTo) {
Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
out = 0;
}
Kernel<kernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
inputs[2].dptr<int>(), outputs[0].dptr<DType>());
});
}
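
// A minimal sketch, not part of this header, of how these pieces are
// typically wired together in a corresponding sample_multinomial_op.cc;
// the op name, argument strings, and resource list here are illustrative
// assumptions, following the usual NNVM registration idiom:
//
//   DMLC_REGISTER_PARAMETER(SampleMultinomialParam);
//
//   NNVM_REGISTER_OP(_sample_multinomial)
//   .set_num_inputs(1)
//   .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
//     const SampleMultinomialParam& param =
//         nnvm::get<SampleMultinomialParam>(attrs.parsed);
//     return param.get_prob ? 2 : 1;
//   })
//   .set_attr_parser(ParamParser<SampleMultinomialParam>)
//   .set_attr<nnvm::FInferShape>("FInferShape", SampleMultinomialOpShape)
//   .set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
//   .set_attr<FResourceRequest>("FResourceRequest",
//     [](const nnvm::NodeAttrs& attrs) {
//       return std::vector<ResourceRequest>{
//           ResourceRequest::kRandom, ResourceRequest::kTempSpace};
//     })
//   .set_attr<FCompute>("FCompute<cpu>", SampleMultinomialForward<cpu>)
//   .add_argument("data", "NDArray-or-Symbol", "Distribution probabilities.")
//   .add_arguments(SampleMultinomialParam::__FIELDS__());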
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_RANDOM_SAMPLE_MULTINOMIAL_OP_H_