scripts/tokenize_stream.py [116:149]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        task_splits = dataset_dict[task]
        for split in task_splits.keys():
            dataset = task_splits[split]
            # Snapshot the column names up front: this info is lost after the map
            column_names = set(dataset.column_names)
            # Discard examples whose "rewards" is an empty list (examples without a "rewards" field are kept)
            dataset = dataset.filter(lambda example: example.get("rewards") != [])

            # Add an initial 0 reward and remove the last reward
            def add_initial_reward(example):
                """Shift the rewards one step to the right.

                A 0.0 is inserted at the front and the last reward is dropped, so the number of rewards is
                unchanged. Examples without a "rewards" field are returned untouched. The example is modified
                in place and returned.
                """
                if "rewards" not in example:
                    return example
                rewards_without_last = example["rewards"][:-1]
                example["rewards"] = [0.0] + rewards_without_last
                return example

            dataset = dataset.map(add_initial_reward)

            # Atari tasks use a shorter sequence length: we've shown that reducing it doesn't impact performance
            # but allows for a larger global batch size. Every other task leaves the length unset (None).
            if task.startswith("atari"):
                max_length = 64
            else:
                max_length = None

            def preprocess(example_batch, max_length):
                """Run the processor on a batch of examples.

                Every column of the batch is forwarded as a keyword argument, together with fixed
                padding="max_length" and truncation="preserve" settings and the given max_length.
                """
                processed = processor(
                    **example_batch,
                    padding="max_length",
                    truncation="preserve",
                    max_length=max_length,
                )
                return processed

            # Only ask for the removal of the raw columns that this dataset actually has
            columns_to_drop = {"text", "images", "text_observations"}.intersection(column_names)
            dataset = dataset.map(
                preprocess,
                batched=True,
                batch_size=1,  # small to avoid OOM
                remove_columns=columns_to_drop,
                fn_kwargs={"max_length": max_length},
            )

            def add_loss_weight(example, loss_weight):
                """Attach a "loss_weight" field to the example.

                The field is a list holding `loss_weight` repeated once per element of the example's first
                field, so it has the same length as that field. The example is modified in place and returned.
                """
                first_field = next(iter(example.values()))
                example["loss_weight"] = [loss_weight for _ in range(len(first_field))]
                return example

            # Tasks without an entry in LOSS_WEIGHTS fall back to a weight of 1.0
            task_loss_weight = LOSS_WEIGHTS.get(task, 1.0)
            dataset = dataset.map(add_loss_weight, fn_kwargs={"loss_weight": task_loss_weight})
            # Store the processed split back in place of the original one
            dataset_dict[task][split] = dataset
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



scripts/train_jat.py [146:179]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        task_splits = dataset_dict[task]
        for split in task_splits.keys():
            dataset = task_splits[split]
            # Snapshot the column names up front: this info is lost after the map
            column_names = set(dataset.column_names)
            # Discard examples whose "rewards" is an empty list (examples without a "rewards" field are kept)
            dataset = dataset.filter(lambda example: example.get("rewards") != [])

            # Add an initial 0 reward and remove the last reward
            def add_initial_reward(example):
                """Shift the rewards one step to the right.

                A 0.0 is inserted at the front and the last reward is dropped, so the number of rewards is
                unchanged. Examples without a "rewards" field are returned untouched. The example is modified
                in place and returned.
                """
                if "rewards" not in example:
                    return example
                rewards_without_last = example["rewards"][:-1]
                example["rewards"] = [0.0] + rewards_without_last
                return example

            dataset = dataset.map(add_initial_reward)

            # Atari tasks use a shorter sequence length: we've shown that reducing it doesn't impact performance
            # but allows for a larger global batch size. Every other task leaves the length unset (None).
            if task.startswith("atari"):
                max_length = 64
            else:
                max_length = None

            def preprocess(example_batch, max_length):
                """Run the processor on a batch of examples.

                Every column of the batch is forwarded as a keyword argument, together with fixed
                padding="max_length" and truncation="preserve" settings and the given max_length.
                """
                processed = processor(
                    **example_batch,
                    padding="max_length",
                    truncation="preserve",
                    max_length=max_length,
                )
                return processed

            # Only ask for the removal of the raw columns that this dataset actually has
            columns_to_drop = {"text", "images", "text_observations"}.intersection(column_names)
            dataset = dataset.map(
                preprocess,
                batched=True,
                batch_size=1,  # small to avoid OOM
                remove_columns=columns_to_drop,
                fn_kwargs={"max_length": max_length},
            )

            def add_loss_weight(example, loss_weight):
                """Attach a "loss_weight" field to the example.

                The field is a list holding `loss_weight` repeated once per element of the example's first
                field, so it has the same length as that field. The example is modified in place and returned.
                """
                first_field = next(iter(example.values()))
                example["loss_weight"] = [loss_weight for _ in range(len(first_field))]
                return example

            # Tasks without an entry in LOSS_WEIGHTS fall back to a weight of 1.0
            task_loss_weight = LOSS_WEIGHTS.get(task, 1.0)
            dataset = dataset.map(add_loss_weight, fn_kwargs={"loss_weight": task_loss_weight})
            # Store the processed split back in place of the original one
            dataset_dict[task][split] = dataset
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



