std::unique_ptr<NeuropodValueMap> TorchNeuropodBackend::infer_internal(const NeuropodValueMap &inputs)

in source/neuropod/backends/torchscript/torch_backend.cc [295:489]


// Runs inference on the loaded TorchScript model.
//
// Inspects the schema of the model's "forward" method to decide how to pass
// `inputs`:
//  - as a single Dict[str, Tensor] or Dict[str, List[str]] argument,
//  - (Torch >= 1.10.2) as a single Dict[str, Union[List[str], Tensor]] argument, or
//  - as individual arguments matched by name against the schema.
// The model's output (a dict, a bare tensor/list, or a (named) tuple) is then
// converted into a NeuropodValueMap.
//
// Throws via NEUROPOD_ERROR if an input name is not in the schema, if a
// non-dict output cannot be matched against `output_specs_`, or if the output
// type is unsupported.
std::unique_ptr<NeuropodValueMap> TorchNeuropodBackend::infer_internal(const NeuropodValueMap &inputs)
{
    // Inference only — disable autograd bookkeeping for the whole call
    torch::NoGradGuard guard;

    // Get inference schema
    const auto &method    = model_->get_method("forward");
    const auto &schema    = SCHEMA(method);
    const auto &arguments = schema.arguments();

    // Whether or not this model expects a dictionary as an input
    bool is_dict_input = false;

    // Torch 1.2.0 adds a ClassType argument to every model
    bool has_class_type = false;

    // Torch 1.10.2 adds UnionType support in TorchScript
    bool dict_value_is_union_type = false;

#if CAFFE2_NIGHTLY_VERSION >= 20190717
    if (arguments.size() > 0 && arguments.at(0).type()->kind() == c10::TypeKind::ClassType)
    {
        has_class_type = true;
    }
#endif

    // Dict input means the dict is the sole "real" argument, after skipping
    // the implicit ClassType argument when present
    if (arguments.size() == 2 && has_class_type && arguments.at(1).type()->kind() == c10::TypeKind::DictType)
    {
        is_dict_input = true;
    }

    if (arguments.size() == 1 && arguments.at(0).type()->kind() == c10::TypeKind::DictType)
    {
        is_dict_input = true;
    }

#if CAFFE2_NIGHTLY_VERSION >= 20220127
    if (is_dict_input && arguments.at(has_class_type ? 1 : 0).type()->cast<torch::DictType>()->getValueType()->kind() ==
                             c10::TypeKind::UnionType)
    {
        dict_value_is_union_type = true;
    }
#endif

    // Define the vector of inputs and add the inputs
    // (sized to the schema minus the implicit ClassType slot)
    std::vector<torch::jit::IValue> torch_inputs(arguments.size() - (has_class_type ? 1 : 0));
    if (is_dict_input && !dict_value_is_union_type)
    {
        // This model expects a dict as input. Build both a tensor-valued and a
        // string-list-valued dict and pick one below based on the schema.
        MAKE_DICT(tensor_input_dict, torch::Tensor);
        MAKE_DICT(str_input_dict, torch::List<std::string>);

        for (const auto &entry : inputs)
        {
            const auto &value = get_ivalue_from_torch_tensor(entry.second);

            if (value.isTensor())
            {
                DICT_INSERT(tensor_input_dict, entry.first, value.toTensor());
            }
            else
            {
                // Non-tensor values are treated as lists of strings; the list
                // API differs across Torch versions
#if CAFFE2_NIGHTLY_VERSION >= 20200421
                DICT_INSERT(str_input_dict, entry.first, c10::impl::toTypedList<std::string>(value.toList()));
#elif CAFFE2_NIGHTLY_VERSION >= 20190717
                DICT_INSERT(str_input_dict, entry.first, c10::impl::toTypedList<std::string>(value.toGenericList()));
#else
                DICT_INSERT(str_input_dict, entry.first, value);
#endif
            }
        }

        // TODO(vip): This assumes a model only takes in string "tensors" or tensors, but not both
        // Refactor to add support for both and add documentation
        if (*arguments.at(has_class_type ? 1 : 0).type()->cast<torch::DictType>()->getValueType() ==
            *torch::TensorType::get())
        {
            torch_inputs.at(0) = tensor_input_dict;
        }
        else
        {
            torch_inputs.at(0) = str_input_dict;
        }
    }
#if CAFFE2_NIGHTLY_VERSION >= 20220127
    // In Torch 1.10.2, TorchScript introduced UnionType and it now supports Dict[str, Union[List[str], torch.Tensor]]
    // as model input type. We would like to support this input type in neuropod torchscript backend
    else if (is_dict_input && dict_value_is_union_type)
    {
        const auto &value_type_ptr = torch::UnionType::create({torch::ListType::ofStrings(), torch::TensorType::get()});
        c10::impl::GenericDict input_dict(torch::StringType::get(), value_type_ptr);

        for (const auto &entry : inputs)
        {
            const auto &value = get_ivalue_from_torch_tensor(entry.second);
            input_dict.insert(entry.first, value);
        }
        torch_inputs.at(0) = input_dict;
    }
#endif
    else
    {
        // Pass inputs normally, matching each one by name against the schema
        for (const auto &entry : inputs)
        {
            const auto &input_name = entry.first;
            const auto &input_data = get_ivalue_from_torch_tensor(entry.second);

            const auto arg_index = schema.argumentIndexWithName(input_name);
            if (!arg_index.has_value())
            {
                NEUROPOD_ERROR(
                    "A tensor named '{}' was provided, but does not exist in the input schema of the "
                    "TorchScript model. Please ensure your model expects an input with that name. Schema: {}",
                    input_name,
                    schema);
            }

            // Shift the schema index down by one to skip the implicit ClassType slot
            torch_inputs.at(static_cast<size_t>(arg_index.value() - (has_class_type ? 1 : 0))) = input_data;
        }
    }

    // Run inference
    c10::IValue result = model_->forward(torch_inputs);

    // Get outputs
    auto to_return = stdx::make_unique<NeuropodValueMap>();

    if (result.isGenericDict())
    {
        process_dict(*to_return, result);
    }
#if CAFFE2_NIGHTLY_VERSION >= 20200421
    else if (result.isTensor() || result.isList())
#else
    else if (result.isTensor() || result.isGenericList())
#endif
    {
        // A bare tensor/list output can only be named via the output spec, so
        // the spec must contain exactly one entry
        if (output_specs_.empty())
        {
            NEUROPOD_ERROR("Model did not return dict and output spec is empty");
        }
        if (output_specs_.size() != 1)
        {
            NEUROPOD_ERROR("Model did not return dict and output spec is not size 1");
        }

        auto &name        = output_specs_[0].name;
        auto &tensor_type = output_specs_[0].type;
        insert_value_in_output(*to_return, name, result, true, tensor_type);
    }
    else if (result.isTuple())
    {
        auto  tuple = result.toTuple();
        auto &elems = tuple->elements();

        // Macros to handle namedtuples (Torch >= 1.3.0)
#if CAFFE2_NIGHTLY_VERSION >= 20191010
        const auto tuple_type     = result.type()->cast<torch::TupleType>();
        const bool is_named_tuple = tuple_type && tuple_type->schema();
#define GET_NAME(i) tuple_type->schema()->arguments()[i].name()
#else
        const bool is_named_tuple = false;
#define GET_NAME(i) ""
#endif
        if (is_named_tuple)
        {
            // This is a named tuple; use the field names as output names
            // NOLINTNEXTLINE(modernize-loop-convert): Can't always use a range based loop here
            for (size_t i = 0; i < elems.size(); i++)
            {
                insert_value_in_output(*to_return, GET_NAME(i), elems.at(i));
            }
        }
        else
        {
            // Each item in this tuple should be a dict
            for (const auto &item : elems)
            {
                if (item.isGenericDict())
                {
                    process_dict(*to_return, item);
                }
                else
                {
                    NEUROPOD_ERROR("When returning a tuple, each item must be a dict. Got {}", item.tagKind());
                }
            }
        }

#undef GET_NAME
    }
    else
    {
        NEUROPOD_ERROR("Torchscript model output type not supported in neuropod");
    }

    return to_return;
}