def index_select()

in graphlearn_torch/python/utils/common.py [0:0]


def index_select(data, index):
  """Recursively apply `index` to every leaf of a nested container.

  Args:
    data: None, a dict, list or tuple (arbitrarily nested), or a leaf object
      that supports ``[]`` indexing / slicing.
    index: Either a ``(start, end)`` tuple, which is applied to each leaf as
      ``leaf[start:end]``, or any other object, which is applied as
      ``leaf[index]``.

  Returns:
    None when `data` is None. Otherwise the result mirrors the structure of
    `data`, rebuilt as plain dict / list / tuple objects (container subclasses
    are not preserved), with every leaf replaced by its indexed value. None
    entries inside containers remain None.
  """
  if data is None:
    return None
  if isinstance(data, dict):
    return {key: index_select(value, index) for key, value in data.items()}
  if isinstance(data, list):
    return [index_select(item, index) for item in data]
  if isinstance(data, tuple):
    # Materialize the children first, then freeze them into a tuple.
    return tuple([index_select(item, index) for item in data])
  # Leaf reached: a tuple index is interpreted as a (start, end) slice range.
  if isinstance(index, tuple):
    begin, stop = index
    return data[begin:stop]
  return data[index]