def make_dataset_full()

in modules/SwissArmyTransformer/sat/data_utils/configure_data.py [0:0]


def make_dataset_full(path, split, args, create_dataset_function, 
        dataset_weights=None, random_mapping=True, is_train_data=False, batch_from_same_dataset=False, **kwargs):
    """function to create datasets+tokenizers for common options"""
    print_all('make dataset ' + str(path), level='DEBUG')
    assert isinstance(path, list)
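    # Three construction modes follow:
    #   1) iterable datasets (e.g. webdataset) combined via AlterDataset;
    #   2) indexed datasets without splitting, concatenated and optionally random-mapped;
    #   3) indexed datasets split into train/valid/test, then concatenated and random-mapped.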

    if (is_train_data and args.iterable_dataset) or (not is_train_data and args.iterable_dataset_eval): # cannot be indexed
        # the random mapping is flexible and efficient, but sometimes we run into practical issues.
        # For instance, someone just gives you an iterable dataset, e.g. a webdataset.
        from .webds import ConfiguredResampledShards, DataPipeline
        valid_types = (ConfiguredResampledShards, DataPipeline)
        
        assert split[0] == 1, 'Iterable datasets cannot be auto-split.'
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            assert isinstance(d, valid_types)
            ds.append(d)
        # ds = ChainDataset(ds)  # if chaining is needed, merge the shards into a single url instead
        if batch_from_same_dataset:
            assert args.num_workers <= 1, 'We cannot control the actual speed of different workers, so different iterable parts may get mixed.'
        ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed, batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
        return ds

    if split is None:
        split = [1.] 
    if not should_split(split):
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            ds.append(d)
        ds = ConcatDataset(ds, weights=dataset_weights)
        if random_mapping:
            if args.epochs is not None: # not auto-scale, but use a given number of epochs.
                ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
            else:
                world_size = torch.distributed.get_world_size(
                    group=mpu.get_data_parallel_group())
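                # enlarge the random mapping so it covers at least the number of samples
                # that will be consumed, and never less than 200x the concatenated length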
                if is_train_data:
                    # only the train dataset sets this to True,
                    # so we enlarge it to make sure that the data is sufficient.
                    scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                else:
                    scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                ds = RandomMappingDataset(ds, scale=scale)
        return ds 
    else:
        # we must first split the datasets, then reweight/concat, and finally apply random mapping.
        # this order avoids overlap between the splits.
        train_ds, valid_ds, test_ds = [], [], []
        for p in path:
            d = create_dataset_function(p, args)
            if should_split(split):
                dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
                train_ds.append(dtrain)
                valid_ds.append(dvalid)
                test_ds.append(dtest)
        train_ds = ConcatDataset(train_ds, weights=dataset_weights)
        valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
        test_ds = ConcatDataset(test_ds, weights=dataset_weights)
        if random_mapping:
            world_size = torch.distributed.get_world_size(
                group=mpu.get_data_parallel_group())
            scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
            train_ds = RandomMappingDataset(train_ds, scale=scale)
            scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
            valid_ds = RandomMappingDataset(valid_ds, scale=scale)
            test_ds = RandomMappingDataset(test_ds)
        return train_ds, valid_ds, test_ds
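
A minimal usage sketch (not part of the source file), assuming make_dataset_full is importable from sat.data_utils.configure_data as shown in the path above. The dataset class, the create_dataset_function, the paths, the weights, and the argparse.Namespace fields are all hypothetical stand-ins; setting split=[1.] together with args.epochs keeps execution on the non-split, epoch-scaled branch, so no torch.distributed setup is needed.

# --- illustrative usage sketch; every name below is a hypothetical stand-in ---
import argparse
from torch.utils.data import Dataset

from sat.data_utils.configure_data import make_dataset_full

class ToyDataset(Dataset):
    """Stand-in for a real map-style dataset built from one path."""
    def __init__(self, path):
        self.samples = list(range(100))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]

def my_create_dataset(path, args):
    # hypothetical create_dataset_function: builds one dataset per path
    return ToyDataset(path)

args = argparse.Namespace(
    iterable_dataset=False,        # use indexed (map-style) datasets
    iterable_dataset_eval=False,
    epochs=3,                      # fixed epoch count -> epoch-scaled random mapping
    seed=1234,
)

train_ds = make_dataset_full(
    path=['data/part-a.jsonl', 'data/part-b.jsonl'],   # hypothetical shards
    split=[1.],                                        # no train/valid/test split here
    args=args,
    create_dataset_function=my_create_dataset,
    dataset_weights=[2, 1],   # assumed semantics: part-a sampled ~2x as often as part-b
    is_train_data=True,
)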