in scripts/tokenize_stream.py [0:0]
import os
import sys
from functools import partial

import datasets
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoProcessor, HfArgumentParser

# ModelArguments, DataTrainingArguments, TASK_NAME_TO_ENV_ID and LOSS_WEIGHTS are assumed to be
# defined elsewhere in this script or imported from the jat package.


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()
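
    # The processor turns raw episodes (text, images, observations, actions, rewards) into token
    # sequences; trust_remote_code lets custom processing code hosted on the hub be used.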
    processor = AutoProcessor.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )
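
    # Domain names ("atari", "babyai", "metaworld", "mujoco") act as shortcuts: each one is
    # expanded below into every task id from TASK_NAME_TO_ENV_ID that starts with that prefix.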
    # Set the tasks
    tasks = data_args.tasks
    for domain in ["atari", "babyai", "metaworld", "mujoco"]:
        if domain in tasks:
            tasks.remove(domain)
            tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])

    # Load the dataset
    dataset_dict = {}
    for task in tasks:
        dataset = load_dataset("jat-project/jat-dataset", task, streaming=True)
        if task == "oscar":
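            # oscar (web text) is very large, so only a subset of the stream is kept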
            dataset = DatasetDict({"train": dataset["train"].take(1_000_000), "test": dataset["test"].take(1_000)})
        dataset_dict[task] = dataset
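
    # Streaming datasets are IterableDatasets; this small generator wrapper is what
    # Dataset.from_generator needs later to materialize them into regular Datasets.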
    def gen_from_iterable_dataset(iterable_ds):
        yield from iterable_ds

    configs = datasets.get_dataset_config_names("jat-project/jat-dataset-tokenized")
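
    # Process one task at a time. Tasks already present as configs of jat-project/jat-dataset-tokenized
    # are skipped, so the script can be re-run and resume where it left off.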
    for task in dataset_dict.keys():
        if task in configs:
            print(f"Task {task} already processed, skipping...")
            continue
        else:
            print(f"Task {task} not processed yet, processing...")
        task_dataset = {}
        for split in dataset_dict[task].keys():
            dataset = dataset_dict[task][split]
            column_names = set(dataset.column_names)  # need to be done here because this info is lost after the map
            # Drop degenerate episodes whose reward list is empty (examples without a "rewards" key
            # are kept, since .get() returns None for them)
            dataset = dataset.filter(lambda example: example.get("rewards") != [])
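
            # Shift rewards by one step so that each timestep carries the reward received for the
            # previous action; the first timestep gets a reward of 0.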
            # Add an initial 0 reward and remove the last reward
            def add_initial_reward(example):
                if "rewards" in example:
                    example["rewards"] = [0.0] + example["rewards"][:-1]
                return example

            dataset = dataset.map(add_initial_reward)

            # We've shown that reducing the sequence length for atari doesn't impact performance but allows for a
            # larger global batch size
            max_length = 64 if task.startswith("atari") else None
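
            # padding/truncation are forwarded to the processor; truncation="preserve" is a mode
            # specific to this processor rather than one of the standard transformers truncation strategies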
            def preprocess(example_batch, max_length):
                return processor(**example_batch, padding="max_length", truncation="preserve", max_length=max_length)

            dataset = dataset.map(
                preprocess,
                batched=True,
                batch_size=1,  # small to avoid OOM
                remove_columns={"text", "images", "text_observations"}.intersection(column_names),
                fn_kwargs={"max_length": max_length},
            )
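
            # Tag every example with a per-task loss weight (LOSS_WEIGHTS, defaulting to 1.0) so
            # that tasks can later be re-weighted in the training loss.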
            def add_loss_weight(example, loss_weight):
                example["loss_weight"] = [loss_weight] * len(next(iter(example.values())))
                return example

            dataset = dataset.map(add_loss_weight, fn_kwargs={"loss_weight": LOSS_WEIGHTS.get(task, 1.0)})
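
            # All the maps above are lazy on an IterableDataset; iterating over it through
            # Dataset.from_generator is what actually runs the tokenization and materializes the split.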
            dataset_dict[task][split] = dataset
            print(f"Generated examples for {task}/{split}")
            task_dataset[split] = Dataset.from_generator(partial(gen_from_iterable_dataset, dataset))

        task_dataset = DatasetDict(task_dataset)
        print(f"Pushing {task} to the hub...")
        task_dataset.push_to_hub("jat-project/jat-dataset-tokenized", config_name=task)
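

# Standard entry point (assumed; the original file may already define it outside this excerpt)
if __name__ == "__main__":
    main()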