# experiments/codes/experiment/checkpointable_multitask_experiment.py
def run_sequential_multitask_training(self):
    """Supervised case I: train one model sequentially on all the tasks."""
    if self.config.model.should_load_model:
        self.load_model()
    if self.epoch is None:
        self.epoch = 0
    # The order matters: the rule names in config.general.train_rule must line
    # up positionally with the dataset names returned for the "train" split.
    train_world_names = self.config.general.train_rule.split(",")
    full_train_world_names = self.gl.get_dataset_names_by_split()["train"]
    if self.config.model.should_train:
        for task_idx in range(len(train_world_names)):
            # Map the rule name onto the full dataset name at the same position.
            train_rule_world = full_train_world_names[task_idx]
            for epoch in range(self.epoch, self.config.model.num_epochs):
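                # Note: the epoch range starts at self.epoch for every task, so
                # a resumed run skips the same leading epochs in each task.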
                self.logbook.write_message_logs(
                    f"Training the model on {train_rule_world}"
                )
                # Train, then validate on the same world.
                train_data = self.dataloaders["train"][train_rule_world]
                self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                metrics = self.eval(
                    {train_rule_world: self.dataloaders["train"][train_rule_world]},
                    epoch=epoch,
                    mode="valid",
                    data_mode="train",
                    task_idx=task_idx,
                )
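                # Step plateau-style LR schedulers on the validation loss
                # (e.g. ReduceLROnPlateau, whose step() takes a metric).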
                for sched in self.schedulers:
                    sched.step(metrics["loss"])
                # Current-task performance.
                self.eval(
                    {train_rule_world: self.dataloaders["train"][train_rule_world]},
                    epoch=epoch,
                    mode="test",
                    data_mode="train",
                )
                if task_idx > 0:
                    # Previous tasks' performance, each mapped to its own dataloader.
                    self.eval(
                        {
                            task: self.dataloaders["train"][
                                full_train_world_names[prev_idx]
                            ]
                            for prev_idx, task in enumerate(
                                train_world_names[:task_idx]
                            )
                        },
                        epoch=epoch,
                        mode="test",
                        data_mode="train_prev",
                    )
                self.periodic_save(task_idx)
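
# A minimal sketch of the config fields this method reads (the field names come
# from the attribute accesses above; everything else about the schema is an
# assumption):
#
#   config.general.train_rule        comma-separated rule names, e.g. "rule_0,rule_1"
#   config.model.should_load_model   bool, resume from a checkpoint before training
#   config.model.should_train        bool, skip the training loop entirely when False
#   config.model.num_epochs          int, number of epochs per task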