def run_sequential_multitask_training()

in experiments/codes/experiment/checkpointable_multitask_experiment.py [0:0]


    def run_sequential_multitask_training(self):
        """Supervised case I: train one model on all the tasks, one after another.

        Tasks are visited in the order of the comma-separated
        ``config.general.train_rule`` string. The i-th entry is paired *by
        position* with the i-th name in
        ``self.gl.get_dataset_names_by_split()["train"]``, and that name selects
        the loader in ``self.dataloaders["train"]``.

        For each task:

        * train for epochs ``self.epoch .. config.model.num_epochs - 1``. After
          every epoch, evaluate on the task's train loader (``mode="valid"``)
          and step every scheduler in ``self.schedulers`` with the returned
          ``"loss"``;
        * evaluate (``mode="test"``) on the current task;
        * if earlier tasks exist, evaluate on all of them at once
          (``data_mode="train_prev"``), each with its own loader;
        * call ``self.periodic_save(task_idx)``.

        The model is loaded first when ``config.model.should_load_model`` is
        set. No training or evaluation happens unless
        ``config.model.should_train`` is true.
        """
        if self.config.model.should_load_model:
            self.load_model()
        if self.epoch is None:
            self.epoch = 0
        # The order is very important here: the i-th rule in ``train_rule`` is
        # paired with the i-th train dataset name. Double check while training.
        train_world_names = self.config.general.train_rule.split(",")
        full_train_world_names = self.gl.get_dataset_names_by_split()["train"]
        if not self.config.model.should_train:
            return

        # enumerate (rather than list.index) keeps task_idx correct even when
        # the same rule name appears more than once in ``train_rule``.
        for task_idx, _rule_name in enumerate(train_world_names):
            train_rule_world = full_train_world_names[task_idx]

            # Bind ``epoch`` before the loop so the evaluations below always
            # see a valid value. If the range is empty (e.g. resuming with
            # self.epoch >= num_epochs), the loop variable would otherwise be
            # unbound (NameError) or stale from the previous task.
            epoch = self.epoch
            for epoch in range(self.epoch, self.config.model.num_epochs):
                self.logbook.write_message_logs(f"Training rule {train_rule_world}")
                self.logbook.write_message_logs(
                    f"Choosing to train the model on {train_rule_world}"
                )
                # Train, optimize and test on the same world.
                train_data = self.dataloaders["train"][train_rule_world]
                self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
                metrics = self.eval(
                    {train_rule_world: self.dataloaders["train"][train_rule_world]},
                    epoch=epoch,
                    mode="valid",
                    data_mode="train",
                    task_idx=task_idx,
                )
                for sched in self.schedulers:
                    sched.step(metrics["loss"])

            # Current task performance.
            self.eval(
                {train_rule_world: self.dataloaders["train"][train_rule_world]},
                epoch=epoch,
                mode="test",
                data_mode="train",
            )
            if task_idx > 0:
                # Previous tasks performance: each earlier task is evaluated on
                # ITS OWN dataset (indexed by prev_idx, not the current task_idx).
                self.eval(
                    {
                        prev_rule: self.dataloaders["train"][
                            full_train_world_names[prev_idx]
                        ]
                        for prev_idx, prev_rule in enumerate(
                            train_world_names[:task_idx]
                        )
                    },
                    epoch=epoch,
                    mode="test",
                    data_mode="train_prev",
                )
            self.periodic_save(task_idx)