bool InferShape()

in src/operator/deconvolution-inl.h [428:618]
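
Infers the weight, bias, and output shapes of the Deconvolution operator from the data shape, dispatching on the kernel dimensionality (1D, 2D, or 3D). Each branch applies the transposed-convolution size rule: out = stride * (in - 1) + dilated_kernel - 2 * pad + adj.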


  bool InferShape(std::vector<TShape> *in_shape,
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
#if MXNET_USE_CUDNN == 0
    if (param_.kernel.ndim() != 2) {
      LOG(FATAL) << "Only 2D deconvolution is supported when cuDNN is not used";
      return false;
    }
#endif  // CUDNN

    using namespace mshadow;
    if (!param_.no_bias) {
      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
    } else {
      CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
    }
    out_shape->resize(1, TShape());
    const TShape &dshape = (*in_shape)[deconv::kData];
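    // A 0-dim data shape means it is not known yet; defer inference until it is set.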
    if (dshape.ndim() == 0) return false;

    if (param_.kernel.ndim() == 1) {
      // 1d deconvolution
      CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3D in batch-channel-x";
      Shape<3> dshape_ncw = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW);
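      // Deconvolution weights are shaped (in_channels, num_filter / num_group,
      // kernel), i.e. with the channel axes transposed relative to convolution.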
      Shape<3> wshape = Shape3(dshape_ncw[1], param_.num_filter / param_.num_group,
                               param_.kernel[0]);
      wshape = ConvertLayout(wshape, kNCW, param_.layout.value());
      SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, wshape);
      if (!param_.no_bias) {
        SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter));
      }

      const index_t dilated_ksize_x = param_.DilatedKernelSize(0);

      index_t o_pad[1];
      index_t o_adj[1];
      param_.InferPad(dshape_ncw, o_pad, o_adj);
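      // InferPad is expected to fill o_pad/o_adj from param_.pad/param_.adj,
      // or to derive them from param_.target_shape when that is set.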

      CHECK_EQ(dshape_ncw[1] % param_.num_group, 0U) \
        << "input channels must be divisible by the number of groups";
      CHECK_EQ(param_.num_filter % param_.num_group, 0U) \
        << "num_filter must be divisible by the number of groups";
      CHECK_GT(param_.kernel.Size(), 0U) \
        << "incorrect kernel size: " << param_.kernel;
      CHECK_GT(param_.stride.Size(), 0U) \
        << "incorrect stride size: " << param_.stride;
      CHECK_GT(param_.dilate.Size(), 0U) \
        << "incorrect dilate size: " << param_.dilate;

      CHECK_GE(param_.stride[0]-1, o_adj[0]) << "adj(x) must be smaller than stride[0]";

      Shape<3> oshape;
      oshape[0] = dshape_ncw[0];
      oshape[1] = param_.num_filter;
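      // Transposed-convolution output size along each spatial axis:
      //   out = stride * (in - 1) + dilated_kernel - 2 * pad + adj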
      oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) +
        dilated_ksize_x - 2 * o_pad[0] + o_adj[0];

      if (param_.target_shape[0] > 0) {
        CHECK_EQ(param_.target_shape[0], oshape[2]) \
          << "param_.target_shape[0] is inconsistent with the inferred output size; please set it carefully";
      }

      SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCW, param_.layout.value()));

      return true;
    } else if (param_.kernel.ndim() == 2) {
      // 2d deconvolution
      CHECK_EQ(dshape.ndim(), 4U) \
        << "Input data should be 4D in batch-channel-y-x";
      Shape<4> dshape_nchw = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW);
      Shape<4> wshape = Shape4(dshape_nchw[1],
                               param_.num_filter / param_.num_group,
                               param_.kernel[0], param_.kernel[1]);
      wshape = ConvertLayout(wshape, kNCHW, param_.layout.value());
      SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, wshape);
      if (!param_.no_bias) {
        SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter));
      }

      const index_t dilated_ksize_y = param_.DilatedKernelSize(0);
      const index_t dilated_ksize_x = param_.DilatedKernelSize(1);
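      // kernel[0] is the y (height) extent and kernel[1] the x (width) extent.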

      index_t o_pad[2];
      index_t o_adj[2];
      param_.InferPad(dshape_nchw, o_pad, o_adj);

      CHECK_EQ(dshape_nchw[1] % param_.num_group, 0U) \
        << "input channels must be divisible by the number of groups";
      CHECK_EQ(param_.num_filter % param_.num_group, 0U) \
        << "num_filter must be divisible by the number of groups";
      CHECK_GT(param_.kernel.Size(), 0U) \
        << "incorrect kernel size: " << param_.kernel;
      CHECK_GT(param_.stride.Size(), 0U) \
        << "incorrect stride size: " << param_.stride;
      CHECK_GT(param_.dilate.Size(), 0U) \
          << "incorrect dilate size: " << param_.dilate;

      CHECK_GE(param_.stride[0]-1, o_adj[0]) << "adj(y) must be smaller than stride[0]";
      CHECK_GE(param_.stride[1]-1, o_adj[1]) << "adj(x) must be smaller than stride[1]";

      Shape<4> oshape;
      oshape[0] = dshape_nchw[0];
      oshape[1] = param_.num_filter;
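      // Output sizes follow the same transposed-convolution rule as the 1D
      // case, applied per axis (y, then x).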
      oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) +
        dilated_ksize_y - 2 * o_pad[0] + o_adj[0];
      oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) +
        dilated_ksize_x - 2 * o_pad[1] + o_adj[1];

      if (param_.target_shape[0] > 0) {
        CHECK_EQ(param_.target_shape[0], oshape[2]) \
          << "param_.target_shape[0] is inconsistent with the inferred output size; please set it carefully";
      }
      if (param_.target_shape[1] > 0) {
        CHECK_EQ(param_.target_shape[1], oshape[3]) \
          << "param_.target_shape[1] is inconsistent with the inferred output size; please set it carefully";
      }

      SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value()));

      return true;
    } else if (param_.kernel.ndim() == 3) {
      // 3d deconvolution
      CHECK_EQ(dshape.ndim(), 5U) \
        << "Input data should be 5D in batch-channel-depth-y-x";
      Shape<5> dshape_ncdhw = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW);
      Shape<5> wshape = Shape5(dshape_ncdhw[1], param_.num_filter / param_.num_group,
                               param_.kernel[0], param_.kernel[1], param_.kernel[2]);
      wshape = ConvertLayout(wshape, kNCDHW, param_.layout.value());
      SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, wshape);
      if (!param_.no_bias) {
        SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter));
      }

      // Note: dilation is currently not supported in 3D deconvolution.
      // The calculations below are kept for symmetry with the 1D/2D code.
      const index_t dilated_ksize_d = param_.DilatedKernelSize(0);
      const index_t dilated_ksize_y = param_.DilatedKernelSize(1);
      const index_t dilated_ksize_x = param_.DilatedKernelSize(2);

      index_t o_pad[3];
      index_t o_adj[3];
      param_.InferPad(dshape_ncdhw, o_pad, o_adj);

      CHECK_EQ(dshape_ncdhw[1] % param_.num_group, 0U) \
        << "input channels must be divisible by the number of groups";
      CHECK_EQ(param_.num_filter % param_.num_group, 0U) \
        << "num_filter must be divisible by the number of groups";
      CHECK_GT(param_.kernel.Size(), 0U) \
        << "incorrect kernel size: " << param_.kernel;
      CHECK_GT(param_.stride.Size(), 0U) \
        << "incorrect stride size: " << param_.stride;
      CHECK_GT(param_.dilate.Size(), 0U) \
        << "incorrect dilate size: " << param_.dilate;
      CHECK_EQ(param_.dilate.Size(), 1U)
        << "Dilation is not supported in 3D deconvolution";

      CHECK_GE(param_.stride[0]-1, o_adj[0]) << "adj(d) must be smaller than stride[0]";
      CHECK_GE(param_.stride[1]-1, o_adj[1]) << "adj(y) must be smaller than stride[1]";
      CHECK_GE(param_.stride[2]-1, o_adj[2]) << "adj(x) must be smaller than stride[2]";

      Shape<5> oshape;
      oshape[0] = dshape_ncdhw[0];
      oshape[1] = param_.num_filter;
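      // Same output-size rule as the 1D/2D cases, applied to d, y, and x.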
      oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) +
        dilated_ksize_d - 2 * o_pad[0] + o_adj[0];
      oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) +
        dilated_ksize_y - 2 * o_pad[1] + o_adj[1];
      oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) +
        dilated_ksize_x - 2 * o_pad[2] + o_adj[2];

      if (param_.target_shape[0] > 0) {
        CHECK_EQ(param_.target_shape[0], oshape[2]) \
          << "param_.target_shape[0] is inconsistent with the inferred output size; please set it carefully";
      }
      if (param_.target_shape[1] > 0) {
        CHECK_EQ(param_.target_shape[1], oshape[3]) \
          << "param_.target_shape[1] is inconsistent with the inferred output size; please set it carefully";
      }
      if (param_.target_shape[2] > 0) {
        CHECK_EQ(param_.target_shape[2], oshape[4]) \
          << "param_.target_shape[2] is inconsistent with the inferred output size; please set it carefully";
      }

      SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCDHW, param_.layout.value()));

      return true;
    } else {
      LOG(FATAL) << "Unknown convolution type";
      return false;
    }
  }
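
As a sanity check, here is a minimal, self-contained sketch of the output-size rule used above; the configuration values are illustrative assumptions, not taken from the source:

  #include <cstdio>

  int main() {
    // Hypothetical 1D setup: 5-wide input, 3-wide kernel, stride 2, no
    // dilation, padding, or output adjustment.
    int in = 5, kernel = 3, stride = 2, dilate = 1, pad = 0, adj = 0;
    int dilated_ksize = dilate * (kernel - 1) + 1;  // effective kernel extent
    int out = stride * (in - 1) + dilated_ksize - 2 * pad + adj;
    std::printf("out = %d\n", out);  // prints "out = 11"
    return 0;
  }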