in src/operator/numpy/random/np_gamma_op.h [258:435]
void NumpyGammaForward(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
  CHECK_EQ(outputs.size(), 1);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  // Generate base random number.
  Random<xpu, FType>* prnd = ctx.requested[0].get_random<xpu, FType>(s);
  index_t output_len = outputs[0].Size();
  // Workspace layout (FType elements):
  //   [0, len*M)           -> uniform draws consumed by the sampler
  //   [len*M, len*2*M)     -> gaussian draws consumed by the sampler
  //   [len*2*M, len*2*M+1) -> one device-side success/failure flag
  Tensor<xpu, 1, FType> random_tensor =
      ctx.requested[1].get_space_typed<xpu, 1, FType>(Shape1(output_len * 2 * M + 1), s);
  Tensor<xpu, 1, FType> uniform_tensor = random_tensor.Slice(0, output_len * M);
  Tensor<xpu, 1, FType> normal_tensor = random_tensor.Slice(output_len * M, output_len * 2 * M);
  prnd->SampleUniform(&uniform_tensor, 0, 1);
  prnd->SampleGaussian(&normal_tensor, 0, 1);
  mxnet::TShape new_lshape, new_hshape, new_oshape;
  // Host-side mirror of the device flag: a value < 0 after a round means at
  // least one output element still carries the failure sentinel, so the
  // rejection sampler must redraw base randoms and retry the failed lanes.
  FType failure_indicator = 1.0;
  Tensor<xpu, 1, FType> failure_indic_workspace =
      random_tensor.Slice(output_len * 2 * M, output_len * 2 * M + 1);
  FType* failure_indicator_device = failure_indic_workspace.dptr_;
  // [scalar scalar] case
  if (inputs.size() == 0U) {
    // Both distribution parameters come from the attrs; validate on the host.
    // (Use LOG(FATAL) for consistency with the other branches below.)
    if (param.shape.value() <= 0) {
      LOG(FATAL) << "ValueError: expect shape > 0";
    }
    if (param.scale.value() <= 0) {
      LOG(FATAL) << "ValueError: expect scale > 0";
    }
    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
      bool in_resample_stage = false;
      do {
        if (in_resample_stage) {
          // Redraw the base random numbers for the retry round.
          prnd->SampleUniform(&uniform_tensor, 0, 1);
          prnd->SampleGaussian(&normal_tensor, 0, 1);
        }
        // Reset the device flag to "success" before this round, exactly as
        // the tensor branches below do.  Without this, the first-iteration
        // _copy further down would read uninitialized workspace memory
        // (the sampling kernel gets nullptr on the first pass and therefore
        // cannot have written the flag).
        Kernel<set_one, xpu>::Launch(s, 1, failure_indicator_device);
        Kernel<gamma_two_scalar_kernel<OType, FType>, xpu>::Launch(
            s,
            outputs[0].Size(),
            param.shape.value(),
            param.scale.value(),
            uniform_tensor.dptr_,
            normal_tensor.dptr_,
            outputs[0].dptr<OType>(),
            in_resample_stage ? failure_indicator_device : nullptr);
        failure_indicator = 1.0;
        // Scan the output for failure sentinels, then mirror the device flag
        // back to the host to decide whether another round is needed.
        Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
            s, outputs[0].Size(), outputs[0].dptr<OType>(), failure_indicator_device);
        _copy<xpu>(s, &failure_indicator, failure_indicator_device);
        in_resample_stage = true;
      } while (failure_indicator < 0);
    });
  } else if (inputs.size() == 1U) {
    // [scalar tensor], [tensor scalar] case: one parameter is a broadcastable
    // tensor, the other a scalar taken from the attrs.
    int ndim = FillShape(inputs[0].shape_,
                         inputs[0].shape_,
                         outputs[0].shape_,
                         &new_lshape,
                         &new_lshape,
                         &new_oshape);
    // scalar_pos tells the kernel which parameter (0 = shape, 1 = scale) is
    // the scalar; the scalar itself is validated on the host here, while the
    // tensor parameter is validated on the device during the first round.
    int scalar_pos;
    float scalar_value;
    if (param.shape.has_value()) {
      scalar_pos = 0;
      scalar_value = param.shape.value();
      if (scalar_value <= 0) {
        LOG(FATAL) << "ValueError: expect shape > 0";
      }
    } else {
      scalar_pos = 1;
      scalar_value = param.scale.value();
      if (scalar_value <= 0) {
        LOG(FATAL) << "ValueError: expect scale > 0";
      }
    }
    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, {
      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
        BROADCAST_NDIM_SWITCH(ndim, NDim, {
          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
          mshadow::Shape<NDim> stride = calc_stride(new_lshape.get<NDim>());
          bool in_resample_stage = false;
          do {
            if (in_resample_stage) {
              // Redraw the base random numbers for the retry round.
              prnd->SampleUniform(&uniform_tensor, 0, 1);
              prnd->SampleGaussian(&normal_tensor, 0, 1);
            }
            // Reset the device flag to "success" before the kernels run.
            Kernel<set_one, xpu>::Launch(s, 1, failure_indicator_device);
            Kernel<gamma_one_scalar_kernel<NDim, IType, OType, FType>, xpu>::Launch(
                s,
                outputs[0].Size(),
                scalar_pos,
                stride,
                oshape,
                inputs[0].dptr<IType>(),
                scalar_value,
                uniform_tensor.dptr_,
                normal_tensor.dptr_,
                outputs[0].dptr<OType>(),
                failure_indicator_device,
                in_resample_stage);
            // We reuse `failure_indicator` for parameter check.
            failure_indicator = 1.0;
            if (!in_resample_stage) {
              // Only check parameter validity in the first trial: the kernel
              // flips the device flag negative if the tensor parameter
              // contains a non-positive entry.
              _copy<xpu>(s, &failure_indicator, failure_indicator_device);
              if (failure_indicator < 0) {
                if (param.shape.has_value()) {
                  // Problematic tensor contains `scale`.
                  LOG(FATAL) << "ValueError: expect scale > 0";
                } else {
                  // Problematic tensor contains `shape`.
                  LOG(FATAL) << "ValueError: expect shape > 0";
                }
              }
            }
            failure_indicator = 1.0;
            // Scan for failure sentinels and mirror the flag back to decide
            // whether to resample.
            Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
                s, outputs[0].Size(), outputs[0].dptr<OType>(), failure_indicator_device);
            _copy<xpu>(s, &failure_indicator, failure_indicator_device);
            in_resample_stage = true;
          } while (failure_indicator < 0);
        });
      });
    });
  } else if (inputs.size() == 2U) {
    // [tensor tensor] case: both parameters are broadcastable tensors.
    int ndim = FillShape(inputs[0].shape_,
                         inputs[1].shape_,
                         outputs[0].shape_,
                         &new_lshape,
                         &new_hshape,
                         &new_oshape);
    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, {
      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
        BROADCAST_NDIM_SWITCH(ndim, NDim, {
          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
          mshadow::Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
          mshadow::Shape<NDim> hstride = calc_stride(new_hshape.get<NDim>());
          bool in_resample_stage = false;
          do {
            if (in_resample_stage) {
              // Redraw the base random numbers for the retry round.
              prnd->SampleUniform(&uniform_tensor, 0, 1);
              prnd->SampleGaussian(&normal_tensor, 0, 1);
            }
            // Reset the device flag to "success" before the kernels run.
            Kernel<set_one, xpu>::Launch(s, 1, failure_indicator_device);
            failure_indicator = 1.0;
            Kernel<gamma_kernel<NDim, IType, OType, FType>, xpu>::Launch(s,
                                                                         outputs[0].Size(),
                                                                         lstride,
                                                                         hstride,
                                                                         oshape,
                                                                         inputs[0].dptr<IType>(),
                                                                         inputs[1].dptr<IType>(),
                                                                         uniform_tensor.dptr_,
                                                                         normal_tensor.dptr_,
                                                                         outputs[0].dptr<OType>(),
                                                                         failure_indicator_device,
                                                                         in_resample_stage);
            if (!in_resample_stage) {
              // Parameter validity is only checked in the first trial: the
              // kernel flips the flag negative on any non-positive entry.
              _copy<xpu>(s, &failure_indicator, failure_indicator_device);
              if (failure_indicator < 0) {
                LOG(FATAL) << "ValueError: expect shape and value > 0";
              }
            }
            failure_indicator = 1.0;
            // Scan for failure sentinels and mirror the flag back to decide
            // whether to resample.
            Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
                s, outputs[0].Size(), outputs[0].dptr<OType>(), failure_indicator_device);
            _copy<xpu>(s, &failure_indicator, failure_indicator_device);
            in_resample_stage = true;
          } while (failure_indicator < 0);
        });
      });
    });
  }
}