in train/compute/python/lib/pytorch/data_impl.py [0:0]
def materialize_arg(arg: Dict[str, Any], device: str) -> Any:
    """
    Given an arg configuration, materialize the test data for that arg.

    ``arg["type"]`` selects a creator from the ``arg_factory`` table below.
    Keys read from ``arg`` per type:

    - ``"tensor"``: ``"shape"``, ``"dtype"`` (looked up in
      ``pytorch_dtype_map``) and optional ``"requires_grad"`` (default
      ``True``; only applied to floating-point/complex dtypes, since autograd
      does not support integer tensors). A non-empty shape yields a random
      tensor: uniform [0, 1) for float/double, integers in [-10, 10) for
      int/long. An empty shape yields a 0-d tensor holding a value drawn from
      uniform(-10, 10), cast to the requested dtype.
    - ``"float"`` / ``"double"`` / ``"int"`` / ``"long"``: ``"value"`` if the
      key is present, otherwise a random draw within ``"value_range"``.
    - ``"str"``: ``"value"`` if the key is present, otherwise ``""``.
    - ``"bool"``: ``"value"``.
    - ``"none"``: always ``None``.
    - ``"device"``: ``torch.device(arg["value"])``.
    - ``"genericlist"`` / ``"tuple"``: ``"value"`` is a list of nested arg
      configurations, materialized recursively into a list / tuple.

    Args:
        arg: the argument configuration.
        device: device string that tensors are created on.

    Returns:
        The materialized Python value or tensor.

    Raises:
        KeyError: if a type or tensor dtype name is not recognized, or a
            required key is missing.
        ValueError: if a non-scalar tensor requests a dtype other than
            float/double/int/long.
    """

    def create_tensor(attr: Dict[str, Any]) -> torch.Tensor:
        shape = attr["shape"]
        dtype_name = attr["dtype"]
        dtype = pytorch_dtype_map[dtype_name]
        torch_device = torch.device(device)
        # Autograd only supports floating-point and complex tensors; requesting
        # gradients on an integer tensor raises a RuntimeError. The flag
        # defaults to True, so it is dropped for integer dtypes.
        requires_grad = bool(attr.get("requires_grad", True)) and (
            dtype.is_floating_point or dtype.is_complex
        )
        if len(shape) > 0:
            if dtype_name in ("float", "double"):
                return torch.rand(
                    tuple(shape),
                    dtype=dtype,
                    requires_grad=requires_grad,
                    device=torch_device,
                )
            if dtype_name in ("int", "long"):
                # torch.randint's upper bound is exclusive: values in [-10, 10).
                return torch.randint(
                    -10,
                    10,
                    tuple(shape),
                    dtype=dtype,
                    requires_grad=requires_grad,
                    device=torch_device,
                )
            raise ValueError(
                f"Unsupported dtype for non-scalar tensor: {dtype_name!r}"
            )
        # Single value (0-d tensor).
        return torch.tensor(
            random.uniform(-10.0, 10.0),
            dtype=dtype,
            requires_grad=requires_grad,
            device=torch_device,
        )

    def create_float(attr: Dict[str, Any]) -> float:
        # Check that the "value" key exists; attr["value"] = 0.0 would be falsy.
        if "value" in attr:
            return attr["value"]
        return random.uniform(attr["value_range"][0], attr["value_range"][1])

    def create_int(attr: Dict[str, Any]) -> int:
        # Check that the "value" key exists; attr["value"] = 0 would be falsy.
        if "value" in attr:
            return attr["value"]
        # random.randint includes both endpoints.
        return random.randint(attr["value_range"][0], attr["value_range"][1])

    def create_str(attr: Dict[str, Any]) -> str:
        # Check that the "value" key exists; attr["value"] = "" would be falsy.
        if "value" in attr:
            return attr["value"]
        return ""

    def create_bool(attr: Dict[str, Any]) -> bool:
        return attr["value"]

    def create_none(attr: Dict[str, Any]) -> None:
        return None

    def create_device(attr: Dict[str, Any]) -> torch.device:
        return torch.device(attr["value"])

    def create_genericlist(attr: Dict[str, Any]) -> List[Any]:
        # Each element is itself an arg config; dispatch recursively.
        return [arg_factory[item["type"]](item) for item in attr["value"]]

    def create_tuple(attr: Dict[str, Any]) -> Tuple[Any, ...]:
        return tuple(create_genericlist(attr))

    # Map of argument types to the create methods.
    arg_factory: Dict[str, Callable] = {
        "tensor": create_tensor,
        "float": create_float,
        "double": create_float,
        "int": create_int,
        "long": create_int,
        "none": create_none,
        "bool": create_bool,
        "device": create_device,
        "str": create_str,
        "genericlist": create_genericlist,
        "tuple": create_tuple,
    }
    return arg_factory[arg["type"]](arg)