in fairseq/data/data_utils.py [0:0]
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    """Filter ``indices``, dropping samples whose size exceeds ``max_positions``.

    Args:
        indices: iterable of dataset sample indices to filter.
        size_fn: callable mapping an index to its size; depending on the
            dataset this may return a number, a tuple, or a dict of tuples
            (see the branches in ``check_size``).
        max_positions: size limit; a number, a tuple, or a dict mirroring
            the structure returned by ``size_fn``.
        raise_exception (bool): accepted for interface compatibility with
            callers; not used in this function body.

    Returns:
        Tuple of (``indices``, ``ignored``): a numpy int64 array of the
        surviving indices, and a list of the indices that were filtered out
        (populated by ``collect_filtered``).
    """

    def check_size(idx):
        # Hoisted: size_fn may be expensive and was previously called up to
        # four times per index on some paths.
        idx_size = size_fn(idx)
        if isinstance(max_positions, (float, int)):
            return idx_size <= max_positions
        if isinstance(max_positions, dict):
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            # A size entry of None on either side means "no constraint".
            return all(
                all(a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key]))
                for key in intersect_keys
            )
        # Hacky as heck, for the specific case of multilingual training with
        # RoundRobin: size_fn returns a dict of tuples, e.g.
        # {'source-target': (14, 32), 'source-untarget': (14, 6)}, while
        # max_positions is a plain tuple, e.g. (1024, 1024).
        if isinstance(idx_size, dict) and isinstance(max_positions, tuple):
            size_values = list(idx_size.values())
            if isinstance(size_values[0], tuple) and len(size_values[0]) == len(max_positions):
                # Every per-key size tuple must fit within max_positions.
                return all(
                    all(a is None or b is None or a <= b
                        for a, b in zip(sv, max_positions))
                    for sv in size_values
                )
            # Dict values are scalars: compare them positionally against
            # max_positions.
            return all(
                a is None or b is None or a <= b
                for a, b in zip(idx_size.values(), max_positions)
            )
        # For MultiCorpusSampledDataset, will generalize it later: a single
        # scalar size must fit under every limit.
        if not isinstance(idx_size, Iterable):
            return all(idx_size <= b for b in max_positions)
        return all(
            a is None or b is None or a <= b
            for a, b in zip(idx_size, max_positions)
        )

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(itr, dtype=np.int64, count=-1)
    return indices, ignored