in src/relax/op/nn/nn.cc [802:995]
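// NLLLoss shape contract, as enforced by the checks below:
//   predictions: (C,) or (N, C, d1, ..., dk)
//   targets:     ()   or (N, d1, ..., dk), with an int/uint dtype
//   weights:     (C,), optional
// Output: (N, d1, ..., dk) (or a scalar) for reduction "none";
//         a scalar for reduction "sum" or "mean".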
StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() < 2 || call->args.size() > 3) {
ctx->ReportFatal(Diagnostic::Error(call) << "NLLLoss op should take 2 or 3 arguments");
}
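  // Extract the tensor struct info of predictions (args[0]), targets (args[1]),
  // and the optional weights (args[2]).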
const auto* pred_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* tgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const TensorStructInfoNode* wgt_sinfo = nullptr;
if (call->args.size() == 3) {
wgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
if (wgt_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "NLLLoss requires the argument weights to be Tensor. However, the given one is "
<< call->args[2]->struct_info_->GetTypeKey());
}
}
if (pred_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "NLLLoss requires the argument preditions to be Tensor. However, the given one is "
<< call->args[0]->struct_info_->GetTypeKey());
}
if (tgt_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "NLLLoss requires the argument targets to be Tensor. However, the given one is "
<< call->args[1]->struct_info_->GetTypeKey());
}
  // Infer the output dtype and vdevice: with weights they follow the binary
  // arithmetic promotion of predictions and weights; otherwise they come from
  // predictions directly.
DataType output_dtype;
Optional<VDevice> vdevice;
if (wgt_sinfo != nullptr) {
output_dtype = InferBinaryArithOpOutDtype(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
GetRef<TensorStructInfo>(wgt_sinfo));
vdevice = InferBinaryArithOpOutVDevice(call, ctx, GetRef<TensorStructInfo>(pred_sinfo),
GetRef<TensorStructInfo>(wgt_sinfo));
} else {
output_dtype = pred_sinfo->dtype;
vdevice = pred_sinfo->vdevice;
}
  // The dtype of targets must be an integer type (int/uint).
if (!tgt_sinfo->IsUnknownDtype() && !tgt_sinfo->dtype.is_int() && !tgt_sinfo->dtype.is_uint()) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is "
<< tgt_sinfo->dtype);
}
  // Infer K, the number of trailing dimensions d1, ..., dk beyond (N, C).
  int K = kUnknownNDim;
if (!pred_sinfo->IsUnknownNdim()) {
if (pred_sinfo->ndim < 1) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "NLLLoss expects the ndim of predictions >= 1. However, the ndim of predictions is "
<< pred_sinfo->ndim);
}
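    // With ndim <= 2 predictions are (C,) or (N, C), so K = 0;
    // otherwise predictions are (N, C, d1, ..., dk) and K = ndim - 2.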
K = pred_sinfo->ndim <= 2 ? 0 : pred_sinfo->ndim - 2;
}
if (!tgt_sinfo->IsUnknownNdim()) {
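    // Targets are () or (N,) when ndim <= 1, so K = 0;
    // otherwise targets are (N, d1, ..., dk) and K = ndim - 1.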
int K_tgt = tgt_sinfo->ndim <= 1 ? 0 : tgt_sinfo->ndim - 1;
if (K != kUnknownNDim && K != K_tgt) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "NLLLoss expects number of dimensions K inferred from different "
"arguments to be equal. However, K from predictions is "
<< K << " while K from targets is " << K_tgt);
}
}
if (wgt_sinfo != nullptr && !wgt_sinfo->IsUnknownNdim() && wgt_sinfo->ndim != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "NLLLoss expects the ndim of weights == 1. However, the ndim of weights is "
<< wgt_sinfo->ndim);
}
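  // Infer N (minibatch size), C (number of classes), and the output shape
  // (N, d1, ..., dk) from the symbolic shapes, using the arithmetic analyzer
  // to prove mismatches where possible.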
arith::Analyzer* analyzer = ctx->GetAnalyzer();
Optional<PrimExpr> N;
Optional<PrimExpr> C;
Array<PrimExpr> output_shape; // N, d1, d2, ..., dk
Optional<Array<PrimExpr>> pred_shape_value;
if (pred_sinfo->shape.defined()) {
pred_shape_value = GetStructInfoAs<ShapeStructInfoNode>(pred_sinfo->shape.value())->values;
}
if (pred_shape_value.defined()) {
if (pred_shape_value.value().size() == 1) {
// (C,)
ICHECK(pred_sinfo->ndim == 1);
C = pred_shape_value.value()[0];
} else {
// (N, C, d1, d2, ..., dk)
ICHECK(pred_shape_value.value().size() >= 2);
ICHECK(pred_sinfo->ndim == static_cast<int>(pred_shape_value.value().size()));
N = pred_shape_value.value()[0];
C = pred_shape_value.value()[1];
output_shape = Array<PrimExpr>();
output_shape.push_back(N.value());
for (size_t i = 2; i < pred_shape_value.value().size(); ++i) {
output_shape.push_back(pred_shape_value.value()[i]);
}
}
}
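  // Cross-check the targets shape against the N and (d1, ..., dk) inferred above.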
Optional<Array<PrimExpr>> tgt_shape_value;
if (tgt_sinfo->shape.defined()) {
tgt_shape_value = GetStructInfoAs<ShapeStructInfoNode>(tgt_sinfo->shape.value())->values;
}
if (tgt_shape_value.defined()) {
if (tgt_shape_value.value().empty()) {
// ()
ICHECK(tgt_sinfo->ndim == 0);
if (N.defined()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Shape mismatch for NLLLoss. Predictions shape is "
"(N, C, ...) while targets is a scalar");
}
} else {
// (N,) or (N, d1, d2, ..., dk)
// check N
const PrimExpr& N_tgt = tgt_shape_value.value()[0];
if (N.defined() && analyzer->CanProve(N.value() != N_tgt)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "NLLLoss expects minibatch size N inferred from different "
"arguments to be equal. However, N from predictions is "
<< N << " while N from targets is " << N_tgt);
}
      // predictions has shape (C,), so targets must be a scalar ()
if (!N.defined() && C.defined()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Shape mismatch for NLLLoss. Predictions shape is "
"(C,) while targets is not a scalar");
}
if (tgt_shape_value.value().size() == 1) {
// (N,)
ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1);
} else {
// (N, d1, d2, ..., dk)
ICHECK(tgt_shape_value.value().size() >= 2);
ICHECK(tgt_sinfo->IsUnknownNdim() ||
tgt_sinfo->ndim == static_cast<int>(tgt_shape_value.value().size()));
if (pred_shape_value.defined()) {
// check (d1, d2, ..., dk)
for (size_t i = 1; i < tgt_shape_value.value().size(); ++i) {
          if (analyzer->CanProve(output_shape[i] != tgt_shape_value.value()[i])) {
            ctx->ReportFatal(Diagnostic::Error(call)
                             << "Shape mismatch for NLLLoss at dimension " << i
                             << ": the prediction shape at this dim is " << output_shape[i]
                             << " while the target shape at this dim is "
                             << tgt_shape_value.value()[i]);
}
}
}
}
}
}
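  // Cross-check the weights shape (C,) against the number of classes C.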
if (wgt_sinfo != nullptr) {
Optional<Array<PrimExpr>> wgt_shape_value;
if (wgt_sinfo->shape.defined()) {
wgt_shape_value = GetStructInfoAs<ShapeStructInfoNode>(wgt_sinfo->shape.value())->values;
}
if (wgt_shape_value.defined()) {
ICHECK(wgt_shape_value.value().size() == 1);
ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1);
const PrimExpr& C_wgt = wgt_shape_value.value()[0];
if (C.defined() && analyzer->CanProve(C.value() != C_wgt)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "NLLLoss expects number of classes C inferred from different "
"arguments to be equal. However, C from predictions is "
<< C << " while C from weights is " << C_wgt);
}
}
}
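  // Determine the output struct info from the reduction mode.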
  const auto* attrs = call->attrs.as<NLLLossAttrs>();
  ICHECK(attrs != nullptr) << "NLLLoss expects its attrs to be NLLLossAttrs";
  String reduction = attrs->reduction;
if (reduction == "none") {
// () or (N,) or (N, d1, d2, ..., dk)
if (pred_sinfo->shape.as<ShapeExprNode>()) {
return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice);
} else {
int output_ndim = pred_sinfo->ndim == kUnknownNDim ? kUnknownNDim : pred_sinfo->ndim - 1;
return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice);
}
} else {
    // reduction is "sum" or "mean": the output is a scalar
return TensorStructInfo(/*shape=*/ShapeExpr(Array<PrimExpr>()), output_dtype, vdevice);
}
}
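// Illustrative example (hypothetical shapes): predictions (32, 10, 7), targets (32, 7)
// with an integer dtype, and weights (10,) infer an output of shape (32, 7) for
// reduction "none", and a scalar output for reduction "mean" or "sum".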