scripts/train_jat.py
#!/usr/bin/env python3
"""Train a JAT model on the JAT dataset"""
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import List, Optional
import datasets.config
from datasets import load_dataset, load_from_disk
from datasets.config import HF_DATASETS_CACHE, HF_DATASETS_OFFLINE
from transformers import AutoConfig, AutoProcessor, HfArgumentParser, Trainer, TrainingArguments
from jat.eval.rl.core import TASK_NAME_TO_ENV_ID
from jat.modeling_jat import JatModel
from jat.utils_interleave_datasets import interleave_datasets
# Sometimes the server is down; increasing the number of retries
# lets the training wait longer instead of crashing.
datasets.config.STREAMING_READ_MAX_RETRIES = 10000
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config we are going to train from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it "
"will execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input to our model for training and eval.
"""
tasks: List[str] = field(default_factory=list, metadata={"help": "Tasks to train on."})
preprocess_num_proc: int = field(
default=1, metadata={"help": "Number of processes to use for preprocessing the data."}
)
eval_num_samples: int = field(default=1000, metadata={"help": "Number of samples to use for evaluation."})
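# Per-task loss weights applied during preprocessing (see add_loss_weight below).
# Later entries override the domain-wide defaults, e.g. mujoco-pendulum uses 50.0 instead of the
# generic mujoco value of 10.0; tasks not listed default to 1.0.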
LOSS_WEIGHTS = {
**{task: 10.0 for task in TASK_NAME_TO_ENV_ID.keys() if task.startswith("mujoco")},
**{task: 50.0 for task in TASK_NAME_TO_ENV_ID.keys() if task.startswith("metaworld")},
"mujoco-pendulum": 50.0,
"mujoco-doublependulum": 20.0,
}
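# Relative sampling weights used when interleaving the training datasets; tasks not listed default to 1.0.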
SAMPLE_WEIGHTS = {
"conceptual-captions": 10.0,
"oscar": 10.0,
"wikipedia": 10.0,
}
os.environ["WANDB_ENTITY"] = "jat-project"
os.environ["WANDB_PROJECT"] = "jat"
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
)
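# Note: the model is instantiated from the config only (randomly initialized weights);
# no pretrained checkpoint is loaded here.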
model = JatModel(config)
processor = AutoProcessor.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
)
# Set the tasks: expand domain names ("atari", "babyai", "metaworld", "mujoco") into their individual tasks
tasks = data_args.tasks
for domain in ["atari", "babyai", "metaworld", "mujoco"]:
if domain in tasks:
tasks.remove(domain)
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
# Load the dataset
# Automatic cache is broken for parquet datasets
# The following is a fix from https://github.com/huggingface/datasets/issues/3547#issuecomment-1252503988
dataset_dict = {}
if HF_DATASETS_OFFLINE:
for task in tasks:
if not os.path.exists(f"{HF_DATASETS_CACHE}/jat-project/jat-dataset/{task}"):
raise ValueError(
f"""Dataset {task} not found in {HF_DATASETS_CACHE}/jat-project/jat-dataset/
Make sure to download and save it first with
```
from datasets import load_dataset
dataset = load_dataset('jat-project/jat-dataset', '{task}')
dataset.save_to_disk('{HF_DATASETS_CACHE}/jat-project/jat-dataset/{task}')
```"""
)
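# Load the dataset previously saved with save_to_disk and expose each split as an iterable dataset,
# so the offline path yields the same streaming-style splits as the online path below.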
dataset = load_from_disk(f"{HF_DATASETS_CACHE}/jat-project/jat-dataset/{task}")
dataset_dict[task] = {s: d.to_iterable_dataset() for s, d in dataset.items()}
else:
for task in tasks:
dataset_dict[task] = load_dataset("jat-project/jat-dataset", task, streaming=True)
# Preprocess the dataset
for task in dataset_dict.keys():
for split in dataset_dict[task].keys():
dataset = dataset_dict[task][split]
column_names = set(dataset.column_names)  # needs to be done here because this info is lost after the map
dataset = dataset.filter(lambda example: example.get("rewards") != [])
# Add an initial 0 reward and remove the last reward
def add_initial_reward(example):
if "rewards" in example:
example["rewards"] = [0.0] + example["rewards"][:-1]
return example
dataset = dataset.map(add_initial_reward)
# We've shown that reducing the sequence length for atari doesn't impact performance but allows for a
# larger global batch size
max_length = 64 if task.startswith("atari") else None
def preprocess(example_batch, max_length):
return processor(**example_batch, padding="max_length", truncation="preserve", max_length=max_length)
dataset = dataset.map(
preprocess,
batched=True,
batch_size=1, # small to avoid OOM
remove_columns={"text", "images", "text_observations"}.intersection(column_names),
fn_kwargs={"max_length": max_length},
)
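# Attach the task's loss weight to every timestep; all columns of a processed example share the
# same length, so the first column is used to size the per-timestep weight list.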
def add_loss_weight(example, loss_weight):
example["loss_weight"] = [loss_weight] * len(next(iter(example.values())))
return example
dataset = dataset.map(add_loss_weight, fn_kwargs={"loss_weight": LOSS_WEIGHTS.get(task, 1.0)})
dataset_dict[task][split] = dataset
train_dataset = {t: d["train"] for t, d in dataset_dict.items()}
eval_dataset = {t: d["test"] for t, d in dataset_dict.items()}
for key in tasks: # Reduce the number of eval samples
eval_dataset[key] = eval_dataset[key].take(data_args.eval_num_samples)
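# Interleave the per-task training streams, sampling each task with probability proportional to its
# weight (1.0 unless overridden in SAMPLE_WEIGHTS).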
weights = [SAMPLE_WEIGHTS.get(t, 1.0) for t in train_dataset.keys()]
train_dataset = interleave_datasets(
list(train_dataset.values()),
probabilities=[w / sum(weights) for w in weights],
seed=training_args.seed,
stopping_strategy="all_exhausted",
n_contiguous=training_args.per_device_train_batch_size,
)
# Due to the train dataset's structure, where every 'n' consecutive samples share the same modalities, we can't
# load all samples at once. Different sets of 'n' samples have different modalities. Therefore, we must load and
# process each set of 'n' samples separately.
if training_args.dispatch_batches is not False:
raise ValueError("Make sure to pass `--dispatch_batches False`.")
# Why does training continue after exhausting the dataset? See https://github.com/huggingface/transformers/issues/26635
trainer = Trainer(
model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=processor
)
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if __name__ == "__main__":
main()