void NumpyGammaForward()

in src/operator/numpy/random/np_gamma_op.h [258:435]


void NumpyGammaForward(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  // Draws gamma samples into outputs[0] using the gamma_* sampling kernels.
  //
  // The mode is selected by the number of tensor inputs:
  //   0 -> `shape` and `scale` are both scalars taken from NumpyGammaParam.
  //   1 -> one parameter is a scalar from NumpyGammaParam and the other one
  //        is inputs[0]: if `param.shape` is set, inputs[0] holds `scale`,
  //        otherwise inputs[0] holds `shape`.
  //   2 -> both parameters are tensors (inputs[0], inputs[1]; presumably
  //        `shape` then `scale`), broadcast against the output shape.
  //
  // ctx.requested[0] provides the random generator; ctx.requested[1] provides
  // a temporary buffer of `output_len * 2 * M + 1` FType values laid out as
  //   [ uniforms (output_len * M) | normals (output_len * M) | flag (1) ].
  //
  // After every sampling pass CheckSuccessKernel inspects the output through
  // the device-side flag; while the flag comes back negative, fresh uniform and
  // normal numbers are drawn and the sampling kernel is relaunched in its
  // resample stage. Non-positive parameters are reported with LOG(FATAL) as
  // "ValueError: ...". `req` is not consulted here.
  using namespace mshadow;
  using namespace mxnet_op;
  const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
  CHECK_EQ(outputs.size(), 1);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  // Generate base random number.
  Random<xpu, FType>* prnd = ctx.requested[0].get_random<xpu, FType>(s);
  index_t output_len       = outputs[0].Size();
  Tensor<xpu, 1, FType> random_tensor =
      ctx.requested[1].get_space_typed<xpu, 1, FType>(Shape1(output_len * 2 * M + 1), s);
  Tensor<xpu, 1, FType> uniform_tensor = random_tensor.Slice(0, output_len * M);
  Tensor<xpu, 1, FType> normal_tensor  = random_tensor.Slice(output_len * M, output_len * 2 * M);
  prnd->SampleUniform(&uniform_tensor, 0, 1);
  prnd->SampleGaussian(&normal_tensor, 0, 1);
  mxnet::TShape new_lshape, new_hshape, new_oshape;
  // Host-side mirror of the device flag; it is only meaningful right after a
  // `_copy<xpu>` from `failure_indicator_device`.
  FType failure_indicator = 1.0;
  // The last workspace slot is the device-side flag shared by the parameter
  // check (tensor modes) and CheckSuccessKernel.
  Tensor<xpu, 1, FType> failure_indic_workspace =
      random_tensor.Slice(output_len * 2 * M, output_len * 2 * M + 1);
  FType* failure_indicator_device = failure_indic_workspace.dptr_;
  if (inputs.size() == 0U) {
    // [scalar scalar] case: both parameters are validated on the host.
    if (param.shape.value() <= 0) {
      LOG(FATAL) << "ValueError: expect shape > 0";
    }
    if (param.scale.value() <= 0) {
      LOG(FATAL) << "ValueError: expect scale > 0";
    }
    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
      bool in_resample_stage = false;
      do {
        if (in_resample_stage) {
          prnd->SampleUniform(&uniform_tensor, 0, 1);
          prnd->SampleGaussian(&normal_tensor, 0, 1);
        }
        Kernel<gamma_two_scalar_kernel<OType, FType>, xpu>::Launch(
            s,
            outputs[0].Size(),
            param.shape.value(),
            param.scale.value(),
            uniform_tensor.dptr_,
            normal_tensor.dptr_,
            outputs[0].dptr<OType>(),
            in_resample_stage ? failure_indicator_device : nullptr);
        // Reset the *device* flag before the check, as the tensor branches do
        // with set_one. Resetting only the host copy has no effect: nothing
        // else initializes this workspace slot on this path, and (assuming
        // CheckSuccessKernel only writes the flag on failure) a negative value
        // left by a failed pass would otherwise keep the loop running forever.
        Kernel<set_one, xpu>::Launch(s, 1, failure_indicator_device);
        Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
            s, outputs[0].Size(), outputs[0].dptr<OType>(), failure_indicator_device);
        _copy<xpu>(s, &failure_indicator, failure_indicator_device);
        in_resample_stage = true;
      } while (failure_indicator < 0);
    });
  } else if (inputs.size() == 1U) {
    // [scalar tensor], [tensor scalar] case
    int ndim = FillShape(inputs[0].shape_,
                         inputs[0].shape_,
                         outputs[0].shape_,
                         &new_lshape,
                         &new_lshape,
                         &new_oshape);
    // scalar_pos tells the kernel which parameter the scalar stands for:
    // 0 -> scalar is `shape` (tensor holds `scale`), 1 -> scalar is `scale`.
    int scalar_pos;
    float scalar_value;
    if (param.shape.has_value()) {
      scalar_pos   = 0;
      scalar_value = param.shape.value();
      if (scalar_value <= 0) {
        LOG(FATAL) << "ValueError: expect shape > 0";
      }
    } else {
      scalar_pos   = 1;
      scalar_value = param.scale.value();
      if (scalar_value <= 0) {
        LOG(FATAL) << "ValueError: expect scale > 0";
      }
    }
    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, {
      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
        BROADCAST_NDIM_SWITCH(ndim, NDim, {
          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
          mshadow::Shape<NDim> stride = calc_stride(new_lshape.get<NDim>());
          bool in_resample_stage      = false;
          do {
            if (in_resample_stage) {
              prnd->SampleUniform(&uniform_tensor, 0, 1);
              prnd->SampleGaussian(&normal_tensor, 0, 1);
            }
            // Fresh device flag for this pass.
            Kernel<set_one, xpu>::Launch(s, 1, failure_indicator_device);
            Kernel<gamma_one_scalar_kernel<NDim, IType, OType, FType>, xpu>::Launch(
                s,
                outputs[0].Size(),
                scalar_pos,
                stride,
                oshape,
                inputs[0].dptr<IType>(),
                scalar_value,
                uniform_tensor.dptr_,
                normal_tensor.dptr_,
                outputs[0].dptr<OType>(),
                failure_indicator_device,
                in_resample_stage);
            if (!in_resample_stage) {
              // Only check parameter validity in the first trial: the kernel
              // reports an invalid tensor element through the device flag.
              _copy<xpu>(s, &failure_indicator, failure_indicator_device);
              if (failure_indicator < 0) {
                if (param.shape.has_value()) {
                  // Problematic tensor contains `scale`.
                  LOG(FATAL) << "ValueError: expect scale > 0";
                } else {
                  // Problematic tensor contains `shape`.
                  LOG(FATAL) << "ValueError: expect shape > 0";
                }
              }
            }
            // The flag is still 1 here: either it was just reset and the
            // parameter check passed, or this is a resample pass.
            Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
                s, outputs[0].Size(), outputs[0].dptr<OType>(), failure_indicator_device);
            _copy<xpu>(s, &failure_indicator, failure_indicator_device);
            in_resample_stage = true;
          } while (failure_indicator < 0);
        });
      });
    });
  } else if (inputs.size() == 2U) {
    // [tensor tensor] case
    int ndim = FillShape(inputs[0].shape_,
                         inputs[1].shape_,
                         outputs[0].shape_,
                         &new_lshape,
                         &new_hshape,
                         &new_oshape);
    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, {
      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
        BROADCAST_NDIM_SWITCH(ndim, NDim, {
          mshadow::Shape<NDim> oshape  = new_oshape.get<NDim>();
          mshadow::Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
          mshadow::Shape<NDim> hstride = calc_stride(new_hshape.get<NDim>());
          bool in_resample_stage       = false;
          do {
            if (in_resample_stage) {
              prnd->SampleUniform(&uniform_tensor, 0, 1);
              prnd->SampleGaussian(&normal_tensor, 0, 1);
            }
            // Fresh device flag for this pass.
            Kernel<set_one, xpu>::Launch(s, 1, failure_indicator_device);
            Kernel<gamma_kernel<NDim, IType, OType, FType>, xpu>::Launch(s,
                                                                         outputs[0].Size(),
                                                                         lstride,
                                                                         hstride,
                                                                         oshape,
                                                                         inputs[0].dptr<IType>(),
                                                                         inputs[1].dptr<IType>(),
                                                                         uniform_tensor.dptr_,
                                                                         normal_tensor.dptr_,
                                                                         outputs[0].dptr<OType>(),
                                                                         failure_indicator_device,
                                                                         in_resample_stage);
            if (!in_resample_stage) {
              // Only check parameter validity in the first trial.
              _copy<xpu>(s, &failure_indicator, failure_indicator_device);
              if (failure_indicator < 0) {
                LOG(FATAL) << "ValueError: expect shape and scale > 0";
              }
            }
            Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
                s, outputs[0].Size(), outputs[0].dptr<OType>(), failure_indicator_device);
            _copy<xpu>(s, &failure_indicator, failure_indicator_device);
            in_resample_stage = true;
          } while (failure_indicator < 0);
        });
      });
    });
  }
}