def _filter_by_size_dynamic()

in fairseq/data/data_utils.py [0:0]


def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    """Keep only the indices whose size, as reported by ``size_fn``, fits ``max_positions``.

    Args:
        indices: iterable of example indices to filter.
        size_fn: callable mapping an index to its size. The size may be a
            number, an iterable of numbers, a dict of key -> tuple, or a dict of
            per-direction tuples (multilingual RoundRobin case).
        max_positions: the limit(s). Supported forms: a number, a dict keyed
            like the size dict, or a tuple/iterable of per-dimension limits.
            A ``None`` entry on either side means "no constraint" for that
            dimension.
        raise_exception: not read by this function; kept for signature
            compatibility (presumably the caller acts on ``ignored`` — verify).

    Returns:
        A tuple ``(kept, ignored)``. ``kept`` is an ``np.int64`` array of the
        indices that passed ``check_size``. ``ignored`` is the list handed to
        ``collect_filtered``, which presumably receives the rejected indices.
    """

    def fits(size, limit):
        # ``None`` on either side means this dimension is unconstrained.
        return size is None or limit is None or size <= limit

    def check_size(idx):
        # Evaluate once: size_fn may be expensive, and repeated calls could
        # disagree if it is not pure.
        idx_size = size_fn(idx)

        if isinstance(max_positions, (float, int)):
            return idx_size <= max_positions

        if isinstance(max_positions, dict):
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(fits(a, b) for a, b in zip(idx_size[key], max_positions[key]))
                for key in intersect_keys
            )

        # Hacky, for the specific case of multilingual training with RoundRobin:
        # size_fn(idx) may be e.g. {'source-target': (14, 32), 'source-untarget': (14, 6)}
        # while max_positions is e.g. (1024, 1024).
        if isinstance(idx_size, dict) and isinstance(max_positions, tuple):
            size_values = list(idx_size.values())
            # Guard against an empty size dict before peeking at its first value.
            if (
                size_values
                and isinstance(size_values[0], tuple)
                and len(size_values[0]) == len(max_positions)
            ):
                # Every per-direction size tuple must fit the shared limits.
                return all(
                    all(fits(a, b) for a, b in zip(sv, max_positions))
                    for sv in size_values
                )
            return all(fits(a, b) for a, b in zip(size_values, max_positions))

        # For MultiCorpusSampledDataset (scalar size vs. several limits).
        if not isinstance(idx_size, Iterable):
            return all(fits(idx_size, b) for b in max_positions)

        return all(fits(a, b) for a, b in zip(idx_size, max_positions))

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(itr, dtype=np.int64, count=-1)
    return indices, ignored