def make_random_splits()

in uimnet/datasets/base.py [0:0]


def make_random_splits(indices, props, generator):
  """Randomly partition ``indices`` into disjoint splits sized by ``props``.

  A random permutation of ``range(len(indices))`` is drawn with
  ``torch.randperm`` and cut into consecutive chunks, one per strictly
  positive proportion. Every chunk except the last holds
  ``int(p * len(indices))`` elements (truncated); the last chunk takes
  whatever remains, so the splits always cover every sample exactly once.

  Args:
    indices: Container of sample indices. It must support ``len()`` and
      indexing by a 1-D integer tensor (presumably a ``torch.Tensor`` or
      a numpy array -- confirm against callers).
    props: Iterable of split proportions. Entries ``<= 0`` are dropped;
      the remaining ones must sum to 1 up to floating-point tolerance.
    generator: ``torch.Generator`` forwarded to ``torch.randperm``; seeding
      it makes the splits reproducible.

  Returns:
    A list with one entry per strictly positive proportion, each being
    ``indices`` indexed by the corresponding slice of the permutation.

  Raises:
    AssertionError: If the positive proportions do not sum to 1, or if the
      produced splits do not add up to ``len(indices)``.
  """
  import math

  props = [p for p in props if p > 0.]
  # Proportions are floats, so an exact ``== 1`` test rejects valid inputs:
  # 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999. Compare with tolerance.
  total = sum(props)
  assert math.isclose(total, 1.0, rel_tol=1e-9, abs_tol=1e-9), (
    f'positive proportions must sum to 1, got {total}')
  nsamples = len(indices)

  shuffled = torch.randperm(nsamples, generator=generator)
  splits = []
  last = len(props) - 1
  lb = 0
  for i, p in enumerate(props):
    # The final split is open-ended so it absorbs the samples left over by
    # the truncation in ``int(p * nsamples)``.
    ub = None if i == last else lb + int(p * nsamples)
    splits.append(indices[shuffled[lb:ub]])
    lb = ub

  # Sanity check: the splits must account for every sample.
  is_consistent = sum(len(s) for s in splits) == nsamples
  assert is_consistent
  return splits