def _for_each_instance_rewrite()

in torch_xla/utils/utils.py [0:0]


def _for_each_instance_rewrite(value, select_fn, fn, rwmap):
  rvalue = rwmap.get(id(value), None)
  if rvalue is not None:
    return rvalue
  result = value
  if select_fn(value):
    result = fn(value)
    rwmap[id(value)] = result
  elif isinstance(value, dict):
    result = dict()
    rwmap[id(value)] = result
    for k, v in value.items():
      k = _for_each_instance_rewrite(k, select_fn, fn, rwmap)
      result[k] = _for_each_instance_rewrite(v, select_fn, fn, rwmap)
  elif isinstance(value, set):
    result = set()
    rwmap[id(value)] = result
    for x in value:
      result.add(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
  elif isinstance(value, (list, tuple)):
    # We transform tuples to lists here, as we need to set the object mapping
    # before calling into the recursion. This code might break if user code
    # expects a tuple.
    result = list()
    rwmap[id(value)] = result
    for x in value:
      result.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
  elif isinstance(value, DataWrapper):
    new_tensors = []
    for x in value.get_tensors():
      new_tensors.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
    result = value.from_tensors(new_tensors)
    rwmap[id(value)] = result
  elif hasattr(value, '__dict__'):
    result = copy.copy(value)
    rwmap[id(value)] = result
    for k in result.__dict__.keys():
      v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
      result.__dict__[k] = v
  else:
    rwmap[id(value)] = result
  return result