in train/compute/python/lib/iterator.py [0:0]
def create_range_iter(arg: Dict[str, Any]):
    """Expand one argument spec into an iterator, or return a copy of it.

    ``arg["type"]`` selects a factory from ``arg_factory_iter`` (defined at the
    end of this function). Attributes listed under ``ATTR_RANGE`` are replaced
    by iterator objects (``full_range``, ``IterableList``, ``ListProduct``) and
    the spec is then wrapped in a ``TableProduct``. Specs with nothing ranged
    come back as a shallow copy.

    Args:
        arg: argument spec dict; must contain a ``"type"`` key.

    Returns:
        A ``TableProduct`` over the spec when ranged attributes were expanded,
        otherwise a shallow copy of ``arg``.

    Raises:
        KeyError: if ``arg["type"]`` (or the type of a nested list/tuple item)
            is missing or has no registered factory.
    """

    def _value_range(attr: Dict[str, Any], make_iter: Callable[[Any], Any]):
        # Shared logic for specs whose only rangeable attribute is "value":
        # replace it with make_iter(value) and wrap the spec in a TableProduct.
        # If "value" is not listed in ATTR_RANGE the copy is returned as-is.
        result = copy.copy(attr)
        if ATTR_RANGE in attr and "value" in set(attr[ATTR_RANGE]):
            result["value"] = make_iter(attr["value"])
            return TableProduct(result)
        return result

    def _list_product(items: List[Dict[str, Any]]):
        # Each element of a generic list/tuple is itself an arg spec, so
        # expand it recursively through the factory table.
        return ListProduct([arg_factory_iter[item["type"]](item) for item in items])

    def create_tensor(attr: Dict[str, Any]):
        logger.debug("%s", attr)
        result = copy.copy(attr)
        # if ranges exists, create iterator for each listed attribute
        if ATTR_RANGE in attr:
            ranges = set(attr[ATTR_RANGE])
            for key, val in attr.items():
                if key in ranges:
                    # The attribute name doubles as the factory key
                    # (e.g. "dtype", "shape").
                    result[key] = arg_factory_iter[key](val)
            return TableProduct(result)
        # otherwise return unchanged
        return result

    def create_float(attr: Dict[str, Any]):
        # Not supporting range float values, any use cases?
        return copy.copy(attr)

    def create_int(attr: Dict[str, Any]):
        # A ranged int "value" holds the positional arguments for full_range.
        return _value_range(attr, lambda val: full_range(*val))

    def create_str(attr: Dict[str, Any]):
        return _value_range(attr, IterableList)

    def create_bool(attr: Dict[str, Any]):
        return _value_range(attr, IterableList)

    def create_none(attr: Dict[str, Any]):
        return copy.copy(attr)

    # Called for a list of data types to be iterated
    def create_dtype(values: List[str]):
        return IterableList(values)

    def create_shape(values: List[Any]):
        # TODO lofe: should also check for ATTR_RANGE
        # A nested list is treated as full_range arguments; scalars are kept.
        shape = [full_range(*val) if isinstance(val, list) else val for val in values]
        return ListProduct(shape)

    def create_device(attr: Dict[str, Any]):
        return _value_range(attr, IterableList)

    def create_genericlist(attr: Dict[str, Any]):
        return _value_range(attr, _list_product)

    def create_tuple(attr: Dict[str, Any]):
        return _value_range(attr, _list_product)

    arg_factory_iter: Dict[str, Callable] = {
        "tensor": create_tensor,
        "float": create_float,
        "double": create_float,
        "int": create_int,
        "long": create_int,
        "str": create_str,
        "none": create_none,
        "bool": create_bool,
        "dtype": create_dtype,
        "shape": create_shape,
        "device": create_device,
        "genericlist": create_genericlist,
        "tuple": create_tuple,
    }
    return arg_factory_iter[arg["type"]](arg)