void SelectAlgo()

in src/operator/nn/cudnn/cudnn_convolution-inl.h [614:853]


  /*!
   * \brief Choose the cuDNN forward, backprop-to-data and backprop-to-filter
   *        convolution algorithms, consulting (and populating) the global
   *        algo registry so the potentially expensive Find()/Get() calls run
   *        at most once per unique convolution configuration.
   *
   * On a registry miss, algorithms are discovered via cudnn*Find*() (when
   * param_.cudnn_tune requests timing) or cudnn*Get*() (when tuning is off),
   * then cached in CuDNNConvAlgoReg.  Regardless of hit or miss, the math
   * type (Tensor Core vs. not) of each convolution descriptor is re-applied
   * at the end to match the selected algorithms.
   *
   * \param rctx run context supplying the GPU stream / cuDNN handle
   * \param in_shape input/weight shapes (part of the registry key)
   * \param out_shape output shapes (part of the registry key)
   * \param cudnn_forward_compute_type compute type for the forward pass
   * \param cudnn_backward_compute_type compute type for the backward passes
   */
  void SelectAlgo(const RunContext& rctx,
                  const std::vector<TShape>& in_shape,
                  const std::vector<TShape>& out_shape,
                  cudnnDataType_t cudnn_forward_compute_type,
                  cudnnDataType_t cudnn_backward_compute_type) {
    if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
                                       cudnn_forward_compute_type, cudnn_backward_compute_type,
                                       SMArch(rctx.ctx.dev_id), add_to_weight_,
                                       &forward_algo_, &back_algo_, &back_algo_w_)) {
      mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
      // param_.workspace is expressed in elements of DType; cuDNN wants bytes.
      size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
      #if CUDNN_MAJOR >= 7
      // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
      // story: the notion of whether the algo ran in Tensor Core mode is not known.
      // Since we want to report the Tensor Core mode in the verbose output, we switch
      // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
      // that of *Find*(), we can unify the find-vs-get logic by using function pointers.

      // Forward Algorithm Find/Get() v7
      std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
      int actual_fwd_algos = 0;
      auto fwd_algo_discoverer =
        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
                                                : cudnnFindConvolutionForwardAlgorithm;
      CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
                                        in_desc_,
                                        filter_desc_,
                                        forward_conv_desc_,
                                        out_desc_,
                                        fwd_results.size(),
                                        &actual_fwd_algos,
                                        fwd_results.data()));
      fwd_results.resize(actual_fwd_algos);
      AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
                      cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
                                                 workspace_byte, &forward_algo_);

      // Backprop-to-Filter Algorithm Find/Get() v7
      auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
      std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
      int actual_bwd_filter_algos = 0;
      // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
      // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
      auto bwd_filter_algo_discoverer =
        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
                                                : cudnnFindConvolutionBackwardFilterAlgorithm;
      CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
                                               in_desc_,
                                               out_desc_,
                                               back_conv_desc_w_,
                                               filter_desc_,
                                               bwd_filt_results.size(),
                                               &actual_bwd_filter_algos,
                                               bwd_filt_results.data()));
      bwd_filt_results.resize(actual_bwd_filter_algos);
      AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
                      cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
                                   workspace_byte, &back_algo_w_);

      // Backprop-to-Data Algorithm Find/Get() v7
      auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
      std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
      int actual_bwd_data_algos = 0;
      auto bwd_data_algo_discoverer =
        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
                                                : cudnnFindConvolutionBackwardDataAlgorithm;
      CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
                                             filter_desc_,
                                             out_desc_,
                                             back_conv_desc_,
                                             in_desc_,
                                             bwd_data_results.size(),
                                             &actual_bwd_data_algos,
                                             bwd_data_results.data()));
      bwd_data_results.resize(actual_bwd_data_algos);
      AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
                      cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
                                    workspace_byte, &back_algo_);
      #else
      // CUDNN_MAJOR < 7
      const int kMaxAlgos = 10;
      int nalgo = kMaxAlgos;
      int i = 0;
      size_t min_memory_needs = 0;
      // Forward Algorithm Find/Get, v6 and earlier
      if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
        forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
      } else if (!param_.cudnn_tune.value()) {
        cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
        CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
                                                 in_desc_,
                                                 filter_desc_,
                                                 forward_conv_desc_,
                                                 out_desc_,
                                                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                                                 workspace_byte,
                                                 &fastest_fwd_algo));
        forward_algo_.Set(fastest_fwd_algo, false);
      } else {
        cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
        CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
                                                        in_desc_,
                                                        filter_desc_,
                                                        forward_conv_desc_,
                                                        out_desc_,
                                                        kMaxAlgos,
                                                        &nalgo,
                                                        fwd_algo));
        // Scan past algos that failed or (in kLimited mode) exceed the workspace,
        // tracking the smallest workspace any rejected algo would have needed so
        // the failure diagnostic below is meaningful.  Note: the running minimum
        // must be updated from element i *before* ++i; updating afterwards would
        // read one entry past the rejected algo (out of bounds on the final
        // iteration) and make the (i == 0) seed branch unreachable.
        i = 0;
        while (i < nalgo
               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
                   || (param_.cudnn_tune.value() == conv::kLimited
                       && fwd_algo[i].memory > workspace_byte))) {
          min_memory_needs =
            (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, fwd_algo[i].memory);
          ++i;
        }
        if (i == nalgo) {
          LOG(FATAL) << nalgo << " forward algorithms with minimum memory requirement "
                     << min_memory_needs << " bytes have been tried. Workspace size is set to "
                     << workspace_byte << " bytes, please consider reducing the batch/model size, "
                     << "or increasing workspace size.";
        } else {
          forward_algo_.Set(fwd_algo[i].algo, false);
        }
      }
      // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
      if (!param_.cudnn_tune.value()) {
        cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
        CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
                                          in_desc_,
                                          out_desc_,
                                          back_conv_desc_w_,
                                          filter_desc_,
                                          CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
                                          workspace_byte,
                                          &fastest_bwd_filt_algo));
        back_algo_w_.Set(fastest_bwd_filt_algo, false);
      } else {
        cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
        CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
                                                               in_desc_,
                                                               out_desc_,
                                                               back_conv_desc_w_,
                                                               filter_desc_,
                                                               kMaxAlgos,
                                                               &nalgo,
                                                               bwd_filter_algo));
        // Same reject-scan as the forward pass; update the minimum before ++i
        // (see comment above).
        i = 0;
        while (i < nalgo
               && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
                   || (param_.cudnn_tune.value() == conv::kLimited
                       && bwd_filter_algo[i].memory > workspace_byte))) {
          min_memory_needs = (i == 0) ?
                             bwd_filter_algo[i].memory :
                             std::min(min_memory_needs, bwd_filter_algo[i].memory);
          ++i;
        }
        if (i == nalgo) {
          LOG(FATAL) << nalgo << " backward filter algorithms with minimum memory requirement "
                     << min_memory_needs << " bytes have been tried. Workspace size is set to "
                     << workspace_byte << " bytes, please consider reducing the batch/model size, "
                     << "or increasing workspace size.";
        } else {
          back_algo_w_.Set(bwd_filter_algo[i].algo, false);
        }
      }
      // Backprop-to-Data Algorithm Get(), v6 and earlier
      if (!param_.cudnn_tune.value()) {
        cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
        CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
                                            filter_desc_,
                                            out_desc_,
                                            back_conv_desc_,
                                            in_desc_,
                                            CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
                                            workspace_byte,
                                            &fastest_bwd_data_algo));
        back_algo_.Set(fastest_bwd_data_algo, false);
      } else {
        cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
        CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
                                                             filter_desc_,
                                                             out_desc_,
                                                             back_conv_desc_,
                                                             in_desc_,
                                                             kMaxAlgos,
                                                             &nalgo,
                                                             bwd_data_algo));
        // Same reject-scan as above; update the minimum before ++i.
        i = 0;
        while (i < nalgo
               && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
                   || (param_.cudnn_tune.value() == conv::kLimited
                       && bwd_data_algo[i].memory > workspace_byte))) {
          min_memory_needs = (i == 0) ?
                             bwd_data_algo[i].memory :
                             std::min(min_memory_needs, bwd_data_algo[i].memory);
          ++i;
        }
        if (i == nalgo) {
          LOG(FATAL) << nalgo << " backward data algorithms with minimum memory requirement "
                     << min_memory_needs << " bytes have been tried. Workspace size is set to "
                     << workspace_byte << " bytes, please consider reducing the batch/model size, "
                     << "or increasing workspace size.";
        } else {
          back_algo_.Set(bwd_data_algo[i].algo, false);
        }
      }
      #endif  // CUDNN_MAJOR < 7

      // Fix for issue #11241
      const int cudnn_find_issue_max_features = 64 * 1024;
      if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) {
        this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
      }

      // An algo specification by the user may be cached here, but another
      // convolution will match only if identically specified.
      // We're caching results of *Get* as well as *Find*, but these records
      // will be held distinctly because param_.cudnn_tune is part of the key.
      CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
                                        cudnn_forward_compute_type,
                                        cudnn_backward_compute_type,
                                        SMArch(rctx.ctx.dev_id), this->add_to_weight_,
                                        this->forward_algo_,
                                        this->back_algo_, this->back_algo_w_);
    }
    // If we're allowing Tensor Core variants of the algos to be considered in
    // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
    // we must change the descriptor to preclude Tensor Core.  Simplest is to
    // once again set the mathType in all cases.
    #if CUDNN_MAJOR >= 7
    CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, forward_algo_.MathType()));
    CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, back_algo_.MathType()));
    CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType()));
    #endif
  }