def validate_neuropod_config()

in source/python/neuropod/utils/config_utils.py [0:0]


def validate_neuropod_config(config):
    """
    Validates a neuropod config.

    Checks that:
      - ``name`` and ``platform`` are strings
      - ``input_tensor_device`` is a dict
      - ``input_spec`` and ``output_spec`` pass ``validate_tensor_spec``
      - the optional ``custom_ops`` field, if present, is a list of strings
      - ``input_tensor_device`` has exactly one entry per tensor named in
        ``input_spec`` (no missing and no extra names), and every entry is
        either "GPU" or "CPU"

    :param config: The config dict to validate.

    :raises KeyError: If one of the required fields "name", "platform",
        "input_tensor_device", "input_spec" or "output_spec" is missing.
    :raises ValueError: If a field has an invalid type or value. Anything
        raised by ``validate_tensor_spec`` is also propagated.
    """
    name = config["name"]
    platform = config["platform"]
    device_mapping = config["input_tensor_device"]

    if not isinstance(name, string_types):
        raise ValueError(
            "Field 'name' in config must be a string! Got value {} of type {}.".format(
                name, type(name)
            )
        )

    if not isinstance(platform, string_types):
        raise ValueError(
            "Field 'platform' in config must be a string! Got value {} of type {}.".format(
                platform, type(platform)
            )
        )

    # Without this check a non-dict value would surface later as an
    # AttributeError from `.keys()` / `.items()` instead of a clear ValueError
    if not isinstance(device_mapping, dict):
        raise ValueError(
            "Field 'input_tensor_device' in config must be a dict! Got value {} of type {}.".format(
                device_mapping, type(device_mapping)
            )
        )

    validate_tensor_spec(config["input_spec"])
    validate_tensor_spec(config["output_spec"])

    # Optional custom ops
    if "custom_ops" in config:
        custom_ops = config["custom_ops"]

        if not isinstance(custom_ops, list):
            raise ValueError(
                "Optional field 'custom_ops' must be a list! Got value {} of type {}".format(
                    custom_ops, type(custom_ops)
                )
            )

        for op in custom_ops:
            if not isinstance(op, string_types):
                raise ValueError(
                    "All items in 'custom_ops' must be strings! Got value {} of type {}.".format(
                        op, type(op)
                    )
                )

    # The set of tensors with a device must match the set of input tensors exactly
    input_tensor_names = {item["name"] for item in config["input_spec"]}
    device_tensor_names = set(device_mapping)
    inputs_without_device = input_tensor_names - device_tensor_names
    devices_without_input = device_tensor_names - input_tensor_names

    if inputs_without_device:
        raise ValueError(
            "Some input tensors do not have devices specified: {}".format(
                inputs_without_device
            )
        )

    if devices_without_input:
        raise ValueError(
            "Devices were specified for some tensors not in the `input_spec`: {}".format(
                devices_without_input
            )
        )

    # Tuple (not set) membership: an unhashable device value still yields the
    # ValueError below rather than a TypeError from hashing
    for tensor_name, device in device_mapping.items():
        if device not in ("GPU", "CPU"):
            raise ValueError(
                "Device must either be 'GPU' or 'CPU'! Got value '{}' for tensor named '{}'.".format(
                    device, tensor_name
                )
            )