def materialize_arg()

in train/compute/python/lib/pytorch/data_impl.py [0:0]


def materialize_arg(arg: Dict[str, Any], device: str) -> Any:
    """
    Given an arg configuration, materialize the test data for that arg.

    Args:
        arg: Argument configuration dict. ``arg["type"]`` selects the factory
            (``"tensor"``, ``"float"``, ``"double"``, ``"int"``, ``"long"``,
            ``"none"``, ``"bool"``, ``"device"``, ``"str"``, ``"genericlist"``,
            ``"tuple"``); the remaining keys are read by that factory.
        device: Device string (e.g. ``"cpu"``) on which tensors are created.

    Returns:
        The materialized value: a ``torch.Tensor``, a Python scalar or string,
        ``None``, a ``torch.device``, or a list/tuple of materialized items.

    Raises:
        KeyError: If ``arg["type"]`` (or a nested item's ``"type"``) is not a
            known factory, or a key a factory needs (e.g. ``"value"``) is
            missing.
        ValueError: If a tensor with a non-empty shape requests a dtype for
            which no random generator is implemented here.
    """
    # Signed integer dtypes for which randint(-10, 10) is a valid range.
    signed_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)

    def create_tensor(attr: Dict[str, Any]) -> torch.Tensor:
        shape = attr["shape"]
        # Honor the declared dtype for every shape (previously only scalars
        # did, so "double" produced float32 and "int" produced int64).
        dtype = pytorch_dtype_map[attr["dtype"]]
        # Autograd only supports floating point / complex tensors; requesting
        # gradients on an integer tensor raises, so the default (True) only
        # takes effect for dtypes that can carry gradients.
        requires_grad = bool(attr.get("requires_grad", True)) and (
            dtype.is_floating_point or dtype.is_complex
        )
        tensor_device = torch.device(device)

        # Single value
        if len(shape) == 0:
            return torch.tensor(
                random.uniform(-10.0, 10.0),
                dtype=dtype,
                requires_grad=requires_grad,
                device=tensor_device,
            )
        if dtype.is_floating_point:
            return torch.rand(
                *shape,
                dtype=dtype,
                requires_grad=requires_grad,
                device=tensor_device,
            )
        if dtype in signed_int_dtypes:
            return torch.randint(
                -10,
                10,
                tuple(shape),
                dtype=dtype,
                requires_grad=requires_grad,
                device=tensor_device,
            )
        # Previously this fell through and silently returned None.
        raise ValueError(
            f"unsupported dtype {attr['dtype']!r} for tensor with shape {shape}"
        )

    def create_float(attr: Dict[str, Any]) -> Any:
        # check "value" key exists, attr["value"] = 0.0 could be eval to False
        if "value" in attr:
            return attr["value"]
        return random.uniform(attr["value_range"][0], attr["value_range"][1])

    def create_int(attr: Dict[str, Any]) -> Any:
        # check "value" key exists, attr["value"] = 0 could be eval to False
        if "value" in attr:
            return attr["value"]
        return random.randint(attr["value_range"][0], attr["value_range"][1])

    def create_str(attr: Dict[str, Any]) -> Any:
        # check "value" key exists, attr["value"] = "" could be eval to False
        if "value" in attr:
            return attr["value"]
        return ""

    def create_bool(attr: Dict[str, Any]) -> Any:
        return attr["value"]

    def create_none(attr: Dict[str, Any]) -> None:
        return None

    def create_device(attr: Dict[str, Any]) -> torch.device:
        return torch.device(attr["value"])

    def create_genericlist(attr: Dict[str, Any]) -> List[Any]:
        # Each element is itself an arg config, materialized recursively.
        return [arg_factory[item["type"]](item) for item in attr["value"]]

    def create_tuple(attr: Dict[str, Any]) -> Tuple[Any, ...]:
        return tuple(create_genericlist(attr))

    # Map of argument types to the create methods.
    arg_factory: Dict[str, Callable] = {
        "tensor": create_tensor,
        "float": create_float,
        "double": create_float,
        "int": create_int,
        "long": create_int,
        "none": create_none,
        "bool": create_bool,
        "device": create_device,
        "str": create_str,
        "genericlist": create_genericlist,
        "tuple": create_tuple,
    }
    return arg_factory[arg["type"]](arg)