in torch_xla/utils/utils.py [0:0]
def _for_each_instance_rewrite(value, select_fn, fn, rwmap):
rvalue = rwmap.get(id(value), None)
if rvalue is not None:
return rvalue
result = value
if select_fn(value):
result = fn(value)
rwmap[id(value)] = result
elif isinstance(value, dict):
result = dict()
rwmap[id(value)] = result
for k, v in value.items():
k = _for_each_instance_rewrite(k, select_fn, fn, rwmap)
result[k] = _for_each_instance_rewrite(v, select_fn, fn, rwmap)
elif isinstance(value, set):
result = set()
rwmap[id(value)] = result
for x in value:
result.add(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
elif isinstance(value, (list, tuple)):
# We transform tuples to lists here, as we need to set the object mapping
# before calling into the recursion. This code might break if user code
# expects a tuple.
result = list()
rwmap[id(value)] = result
for x in value:
result.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
elif isinstance(value, DataWrapper):
new_tensors = []
for x in value.get_tensors():
new_tensors.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
result = value.from_tensors(new_tensors)
rwmap[id(value)] = result
elif hasattr(value, '__dict__'):
result = copy.copy(value)
rwmap[id(value)] = result
for k in result.__dict__.keys():
v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
result.__dict__[k] = v
else:
rwmap[id(value)] = result
return result