StructInfo InferStructInfoNLLLoss()

in src/relax/op/nn/nn.cc [802:995]


/*!
 * \brief Infer the output struct info of an NLLLoss call.
 *
 * Accepted argument forms (all must be Tensors):
 *   - predictions: (C,) or (N, C, d1, d2, ..., dk)
 *   - targets:     ()   or (N, d1, d2, ..., dk), with an int/uint dtype
 *   - weights:     (C,), optional third argument
 *
 * Each argument is validated as far as its known ndim / shape allows. Symbolic dimensions
 * are only rejected when the arithmetic analyzer can prove they differ.
 *
 * \param call The call node. Its attrs are expected to be NLLLossAttrs.
 * \param ctx The block builder, used for diagnostics and its arithmetic analyzer.
 * \return With reduction "none": a tensor of shape () or (N, d1, ..., dk). The exact shape is
 *         returned only when the predictions' shape is a ShapeExpr; otherwise only the ndim
 *         (predictions ndim - 1) is returned. With "sum" or "mean": a scalar tensor.
 *         The dtype/vdevice are taken from the predictions, combined with the weights via
 *         InferBinaryArithOpOutDtype / InferBinaryArithOpOutVDevice when weights are given.
 */
StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() < 2 || call->args.size() > 3) {
    ctx->ReportFatal(Diagnostic::Error(call) << "NLLLoss op should take 2 or 3 arguments");
  }

  // ---- Argument kinds ----------------------------------------------------------------------
  // The code below dereferences these pointers right after each check, i.e. it relies on
  // ReportFatal not returning.
  const auto* pred_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  if (pred_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "NLLLoss requires the argument predictions to be Tensor. However, the given one is "
        << call->args[0]->struct_info_->GetTypeKey());
  }
  const auto* tgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
  if (tgt_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "NLLLoss requires the argument targets to be Tensor. However, the given one is "
        << call->args[1]->struct_info_->GetTypeKey());
  }
  const TensorStructInfoNode* wgt_sinfo = nullptr;  // stays null when no weights are passed
  if (call->args.size() == 3) {
    wgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
    if (wgt_sinfo == nullptr) {
      ctx->ReportFatal(
          Diagnostic::Error(call)
          << "NLLLoss requires the argument weights to be Tensor. However, the given one is "
          << call->args[2]->struct_info_->GetTypeKey());
    }
  }

  // ---- Output dtype and vdevice ------------------------------------------------------------
  DataType output_dtype;
  Optional<VDevice> vdevice;
  if (wgt_sinfo != nullptr) {
    output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
                                              GetRef<TensorStructInfo>(wgt_sinfo));
    vdevice = InferBinaryArithOpOutVDevice(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
                                           GetRef<TensorStructInfo>(wgt_sinfo));
  } else {
    output_dtype = pred_sinfo->dtype;
    vdevice = pred_sinfo->vdevice;
  }

  // Targets hold class indices, so their dtype must be int/uint (when known).
  if (!tgt_sinfo->IsUnknownDtype() && !tgt_sinfo->dtype.is_int() && !tgt_sinfo->dtype.is_uint()) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is "
        << tgt_sinfo->dtype);
  }

  // ---- ndim consistency --------------------------------------------------------------------
  // K is the number of trailing spatial dims (d1, ..., dk).
  int K = kUnknownNDim;
  if (!pred_sinfo->IsUnknownNdim()) {
    if (pred_sinfo->ndim < 1) {
      ctx->ReportFatal(
          Diagnostic::Error(call)
          << "NLLLoss expects the ndim of predictions >= 1. However, the ndim of predictions is "
          << pred_sinfo->ndim);
    }
    K = pred_sinfo->ndim <= 2 ? 0 : pred_sinfo->ndim - 2;
  }
  if (!tgt_sinfo->IsUnknownNdim()) {
    int K_tgt = tgt_sinfo->ndim <= 1 ? 0 : tgt_sinfo->ndim - 1;
    if (K != kUnknownNDim && K != K_tgt) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "NLLLoss expects number of dimensions K inferred from different "
                          "arguments to be equal. However, K from predictions is "
                       << K << " while K from targets is " << K_tgt);
    }
    // K alone cannot tell (C,) from (N, C) predictions, nor () from (N,) targets. In every
    // valid form the targets have exactly one dimension fewer than the predictions. Checking
    // that here also covers inputs whose concrete shapes are unknown.
    if (!pred_sinfo->IsUnknownNdim() && tgt_sinfo->ndim != pred_sinfo->ndim - 1) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "NLLLoss expects the ndim of targets to be the ndim of predictions "
                          "minus 1. However, the ndim of predictions is "
                       << pred_sinfo->ndim << " while the ndim of targets is "
                       << tgt_sinfo->ndim);
    }
  }
  if (wgt_sinfo != nullptr && !wgt_sinfo->IsUnknownNdim() && wgt_sinfo->ndim != 1) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "NLLLoss expects the ndim of weights == 1. However, the ndim of weights is "
                     << wgt_sinfo->ndim);
  }

  // ---- Shape consistency -------------------------------------------------------------------
  // Returns the dimension values recorded in a tensor's shape struct info, or a null Optional
  // when they are not available.
  auto known_shape_values = [](const TensorStructInfoNode* sinfo) -> Optional<Array<PrimExpr>> {
    if (!sinfo->shape.defined()) {
      return Optional<Array<PrimExpr>>();
    }
    return GetStructInfoAs<ShapeStructInfoNode>(sinfo->shape.value())->values;
  };

  arith::Analyzer* analyzer = ctx->GetAnalyzer();
  Optional<PrimExpr> N;          // minibatch size; only set for (N, C, ...) predictions
  Optional<PrimExpr> C;          // number of classes
  Array<PrimExpr> output_shape;  // (N, d1, d2, ..., dk); stays empty for (C,) predictions

  Optional<Array<PrimExpr>> pred_shape_value = known_shape_values(pred_sinfo);
  if (pred_shape_value.defined()) {
    Array<PrimExpr> pred_shape = pred_shape_value.value();
    if (pred_shape.size() == 1) {
      // (C,)
      ICHECK(pred_sinfo->ndim == 1);
      C = pred_shape[0];
    } else {
      // (N, C, d1, d2, ..., dk)
      ICHECK(pred_shape.size() >= 2);
      ICHECK(pred_sinfo->ndim == static_cast<int>(pred_shape.size()));
      N = pred_shape[0];
      C = pred_shape[1];
      output_shape.push_back(pred_shape[0]);
      for (size_t i = 2; i < pred_shape.size(); ++i) {
        output_shape.push_back(pred_shape[i]);
      }
    }
  }

  Optional<Array<PrimExpr>> tgt_shape_value = known_shape_values(tgt_sinfo);
  if (tgt_shape_value.defined()) {
    Array<PrimExpr> tgt_shape = tgt_shape_value.value();
    if (tgt_shape.empty()) {
      // () -- only valid together with (C,) predictions.
      ICHECK(tgt_sinfo->ndim == 0);
      if (N.defined()) {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "Shape mismatch for NLLLoss. Predictions shape is "
                            "(N, C, ...) while targets is a scalar");
      }
    } else {
      // (N,) or (N, d1, d2, ..., dk)
      const PrimExpr& N_tgt = tgt_shape[0];
      if (N.defined() && analyzer->CanProve(N.value() != N_tgt)) {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "NLLLoss expects minibatch size N inferred from different "
                            "arguments to be equal. However, N from predictions is "
                         << N.value() << " while N from targets is " << N_tgt);
      }
      // (C,) predictions require scalar targets.
      if (!N.defined() && C.defined()) {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "Shape mismatch for NLLLoss. Predictions shape is "
                            "(C,) while targets is not a scalar");
      }

      if (tgt_shape.size() == 1) {
        // (N,)
        ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1);
      } else {
        // (N, d1, d2, ..., dk)
        ICHECK(tgt_shape.size() >= 2);
        ICHECK(tgt_sinfo->IsUnknownNdim() ||
               tgt_sinfo->ndim == static_cast<int>(tgt_shape.size()));

        if (pred_shape_value.defined()) {
          // Compare (d1, d2, ..., dk). The ndim checks above guarantee that output_shape and
          // tgt_shape have the same length here.
          for (size_t i = 1; i < tgt_shape.size(); ++i) {
            if (analyzer->CanProve(output_shape[i] != tgt_shape[i])) {
              ctx->ReportFatal(Diagnostic::Error(call)
                               << "Shape mismatch for NLLLoss. The prediction shape at this dim is "
                               << output_shape[i] << " while the target shape at this dim is "
                               << tgt_shape[i]);
            }
          }
        }
      }
    }
  }

  if (wgt_sinfo != nullptr) {
    Optional<Array<PrimExpr>> wgt_shape_value = known_shape_values(wgt_sinfo);
    if (wgt_shape_value.defined()) {
      ICHECK(wgt_shape_value.value().size() == 1);
      ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1);
      const PrimExpr& C_wgt = wgt_shape_value.value()[0];
      if (C.defined() && analyzer->CanProve(C.value() != C_wgt)) {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "NLLLoss expects number of classes C inferred from different "
                            "arguments to be equal. However, C from predictions is "
                         << C.value() << " while C from weights is " << C_wgt);
      }
    }
  }

  // ---- Result ------------------------------------------------------------------------------
  const auto* attrs = call->attrs.as<NLLLossAttrs>();
  ICHECK(attrs != nullptr) << "NLLLoss expects its attributes to be NLLLossAttrs";
  String reduction = attrs->reduction;
  // Any other string would previously have been treated as "sum"/"mean" without complaint.
  if (reduction != "none" && reduction != "sum" && reduction != "mean") {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "NLLLoss expects reduction to be one of \"none\", \"sum\" or \"mean\". "
                        "However, the given reduction is "
                     << reduction);
  }

  if (reduction == "none") {
    // () or (N,) or (N, d1, d2, ..., dk)
    if (pred_sinfo->shape.as<ShapeExprNode>()) {
      return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice);
    }
    int output_ndim = pred_sinfo->IsUnknownNdim() ? kUnknownNDim : pred_sinfo->ndim - 1;
    return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice);
  }
  // "sum" or "mean": the output is a scalar.
  return TensorStructInfo(/*shape=*/ShapeExpr(Array<PrimExpr>()), output_dtype, vdevice);
}