in scripts/train_jat_tokenized.py [0:0]
def main():
    """Parse the arguments, build the model, load the datasets and train.

    Arguments come either from a single ``.json`` file given as the only
    command-line argument, or from regular command-line flags. Domain
    aliases in ``data_args.tasks`` ("atari", "babyai", "metaworld",
    "mujoco") are expanded to every key of ``TASK_NAME_TO_ENV_ID`` that
    starts with the alias. The per-task train splits are then interleaved
    using the weights in ``SAMPLE_WEIGHTS`` (default weight 1.0) and passed
    to ``MyTrainer``.

    Raises:
        ValueError: If ``training_args.dispatch_batches`` is not ``False``,
            if no task is left after alias expansion, or if offline mode is
            enabled and a task's dataset is missing from the local cache.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARNING,
    )

    # Due to the train dataset's structure, where every 'n' consecutive samples share the same modalities, we can't
    # load all samples at once. Different sets of 'n' samples have different modalities. Therefore, we must load and
    # process each set of 'n' samples separately.
    # This is checked up front so that a missing flag fails before the model is built and the datasets are loaded.
    if training_args.dispatch_batches is not False:
        raise ValueError("Make sure to pass `--dispatch_batches False`.")

    # Set the tasks. Work on a copy so that the parsed arguments are not mutated.
    tasks = list(data_args.tasks)
    for domain in ["atari", "babyai", "metaworld", "mujoco"]:
        if domain in tasks:
            tasks.remove(domain)
            tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
    # A task may be listed explicitly and also be covered by a domain alias. Keep the first occurrence only, so
    # each dataset is loaded once (the order is the one a dict keyed by task would have).
    tasks = list(dict.fromkeys(tasks))
    if not tasks:
        raise ValueError("No task to train on: `data_args.tasks` is empty.")

    pretrained_source = model_args.config_name or model_args.model_name_or_path
    config = AutoConfig.from_pretrained(
        pretrained_source,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = JatModel(config)
    processor = AutoProcessor.from_pretrained(
        pretrained_source,
        cache_dir=model_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )

    # Load the datasets
    if HF_DATASETS_OFFLINE:
        dataset_root = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized"
        # Check every task before loading anything, so a missing dataset is reported immediately.
        for task in tasks:
            if not os.path.exists(f"{dataset_root}/{task}"):
                raise ValueError(
                    f"Dataset {task} not found in {dataset_root}/\n"
                    "Make sure to download and save it first with\n"
                    "```\n"
                    "from datasets import load_dataset\n"
                    f"dataset = load_dataset('jat-project/jat-dataset-tokenized', '{task}')\n"
                    f"dataset.save_to_disk('{dataset_root}/{task}')\n"
                    "```"
                )
        train_dataset = {}
        for task in tqdm(tasks, desc="Loading datasets"):
            saved_dataset = load_from_disk(f"{dataset_root}/{task}")
            train_dataset[task] = saved_dataset["train"]
    else:
        train_dataset = {
            task: load_dataset("jat-project/jat-dataset-tokenized", task, split="train") for task in tasks
        }

    # Turn the per-task weights into sampling probabilities (the total is computed once).
    weights = [SAMPLE_WEIGHTS.get(task, 1.0) for task in train_dataset]
    total_weight = sum(weights)
    train_dataset = interleave_datasets(
        list(train_dataset.values()),
        probabilities=[weight / total_weight for weight in weights],
        seed=training_args.seed,
        stopping_strategy="all_exhausted",
        n_contiguous=training_args.per_device_train_batch_size,
    )

    # Why does the training continue after exhausting the dataset?
    # See https://github.com/huggingface/transformers/issues/26635
    trainer = MyTrainer(model=model, args=training_args, train_dataset=train_dataset, tokenizer=processor)
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)