#!/usr/bin/env python3
"""Train a JAT model on the JAT dataset"""
import os
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional
import datasets.config
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoProcessor, HfArgumentParser
from jat.eval.rl.core import TASK_NAME_TO_ENV_ID
# Sometimes the Hub server is down; increasing the number of retries
# lets the streaming reader wait and retry instead of crashing the job.
datasets.config.STREAMING_READ_MAX_RETRIES = 10000
@dataclass
class ModelArguments:
"""
    Arguments pertaining to which model/processor we are going to load.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it "
"will execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
    Arguments pertaining to what data we are going to tokenize.
"""
tasks: List[str] = field(default_factory=list, metadata={"help": "Tasks to train on."})
preprocess_num_proc: int = field(
default=1, metadata={"help": "Number of processes to use for preprocessing the data."}
)
eval_num_samples: int = field(default=1000, metadata={"help": "Number of samples to use for evaluation."})
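# Per-task loss weights used to rebalance training across domains. Later keys
# override the comprehension-generated defaults, so "mujoco-pendulum" gets 50.0
# and "mujoco-doublependulum" gets 20.0 instead of the generic mujoco 10.0.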
LOSS_WEIGHTS = {
**{task: 10.0 for task in TASK_NAME_TO_ENV_ID.keys() if task.startswith("mujoco")},
**{task: 50.0 for task in TASK_NAME_TO_ENV_ID.keys() if task.startswith("metaworld")},
"mujoco-pendulum": 50.0,
"mujoco-doublependulum": 20.0,
}
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args = parser.parse_args_into_dataclasses()
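    # Load the JAT processor; it is assumed here to handle both text
    # tokenization and the preparation of the RL modalities (images,
    # continuous observations and actions).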
processor = AutoProcessor.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
)
    # Expand domain shorthands (e.g. "mujoco") into their individual task names
tasks = data_args.tasks
for domain in ["atari", "babyai", "metaworld", "mujoco"]:
if domain in tasks:
tasks.remove(domain)
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
# Load the dataset
dataset_dict = {}
for task in tasks:
dataset = load_dataset("jat-project/jat-dataset", task, streaming=True)
if task == "oscar":
dataset = DatasetDict({"train": dataset["train"].take(1_000_000), "test": dataset["test"].take(1_000)})
dataset_dict[task] = dataset
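    # Streaming (iterable) datasets cannot be pushed to the Hub directly, so we
    # materialize each split through Dataset.from_generator before uploading.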
def gen_from_iterable_dataset(iterable_ds):
yield from iterable_ds
configs = datasets.get_dataset_config_names("jat-project/jat-dataset-tokenized")
for task in dataset_dict.keys():
if task in configs:
print(f"Task {task} already processed, skipping...")
continue
else:
print(f"Task {task} not processed yet, processing...")
task_dataset = {}
for split in dataset_dict[task].keys():
dataset = dataset_dict[task][split]
            column_names = set(dataset.column_names)  # must be captured here: this info is lost after map()
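            # Drop episodes whose reward list is empty; text-only datasets have
            # no "rewards" key and pass through unchanged.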
dataset = dataset.filter(lambda example: example.get("rewards") != [])
# Add an initial 0 reward and remove the last reward
def add_initial_reward(example):
if "rewards" in example:
example["rewards"] = [0.0] + example["rewards"][:-1]
return example
dataset = dataset.map(add_initial_reward)
# We've shown that reducing the sequence length for atari doesn't impact performance but allows for a
# larger global batch size
max_length = 64 if task.startswith("atari") else None
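            # truncation="preserve" is a JAT-processor mode: over-long episodes
            # are split into multiple chunks rather than having the overflow
            # discarded.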
def preprocess(example_batch, max_length):
return processor(**example_batch, padding="max_length", truncation="preserve", max_length=max_length)
dataset = dataset.map(
preprocess,
batched=True,
batch_size=1, # small to avoid OOM
remove_columns={"text", "images", "text_observations"}.intersection(column_names),
fn_kwargs={"max_length": max_length},
)
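            # Tag every timestep of the example with the task-level loss weight
            # so the trainer can rebalance tasks when computing the loss.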
def add_loss_weight(example, loss_weight):
example["loss_weight"] = [loss_weight] * len(next(iter(example.values())))
return example
dataset = dataset.map(add_loss_weight, fn_kwargs={"loss_weight": LOSS_WEIGHTS.get(task, 1.0)})
dataset_dict[task][split] = dataset
print(f"Generated examples for {task}/{split}")
task_dataset[split] = Dataset.from_generator(partial(gen_from_iterable_dataset, dataset))
task_dataset = DatasetDict(task_dataset)
print(f"Pushing {task} to the hub...")
task_dataset.push_to_hub("jat-project/jat-dataset-tokenized", config_name=task)
if __name__ == "__main__":
main()
# Usage: python scripts/tokenize_stream.py --model_name_or_path jat-project/jat-small --tasks mujoco --trust_remote_code
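# Arguments can also be supplied as a single JSON file (see main()), e.g.:
#   python scripts/tokenize_stream.py path/to/args.json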