def run_single_task()

in experiments/codes/experiment/checkpointable_multitask_experiment.py [0:0]


    def run_single_task(self, world_mode="train"):
        """Train, validate and test on the single world selected by the config.

        Supervised (single-task) setup. Among the worlds in
        ``self.dataloaders[world_mode]``, the one whose last "/"-separated name
        component equals ``self.config.general.train_rule`` is selected. If
        ``self.config.model.should_train`` is truthy, the model is trained on
        that world for epochs ``self.epoch`` (0 if unset) up to
        ``self.config.model.num_epochs`` (exclusive).

        Each epoch does the following, in order:

        1. ``self.train``, then a checkpoint via ``self.periodic_save``.
        2. A ``"valid"`` ``self.eval``; its ``"loss"`` is fed to every
           scheduler in ``self.schedulers``.
        3. A ``"test"`` ``self.eval``.
        4. Optional logging of weight norms when
           ``self.config.logger.watch_model`` is set.
        5. A second ``self.periodic_save``.

        :param world_mode: key into ``self.dataloaders`` selecting the group of
            worlds to search (default ``"train"``); also forwarded to
            ``self.eval`` as ``data_mode``.
        :return: None
        :raises ValueError: if no world's last name component matches
            ``self.config.general.train_rule``.
        """
        if self.epoch is None:
            self.epoch = 0

        world_names = list(self.dataloaders[world_mode])
        train_rule = self.config.general.train_rule
        # Worlds are keyed by path-like names; the configured rule is matched
        # against the final "/" component only. If several worlds share that
        # component, the first one in dataloader order is used.
        rule_names = [name.split("/")[-1] for name in world_names]
        try:
            task_idx = rule_names.index(train_rule)
        except ValueError:
            raise ValueError(
                f"train_rule {train_rule!r} does not match any world in "
                f"dataloaders[{world_mode!r}]: {world_names}"
            ) from None
        train_rule_world = world_names[task_idx]

        if not self.config.model.should_train:
            return

        for epoch in range(self.epoch, self.config.model.num_epochs):
            self.logbook.write_message_logs(f"Training rule {train_rule_world}")
            self.logbook.write_message_logs(
                f"Choosing to train the model on {train_rule_world}"
            )

            train_data = self.dataloaders[world_mode][train_rule_world]
            self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
            # NOTE(review): self.epoch stores the epoch that was just trained,
            # and the loop above starts from self.epoch. A run restored from
            # this checkpoint would therefore repeat this epoch unless the
            # loading code advances it. Confirm against the checkpoint loader.
            self.epoch = epoch
            self.periodic_save(epoch=epoch)

            metrics = self.eval(
                {train_rule_world: train_data},
                epoch=epoch,
                mode="valid",
                data_mode=world_mode,
                task_idx=task_idx,
            )
            for sched in self.schedulers:
                sched.step(metrics["loss"])

            # NOTE(review): the "test" pass receives the same dataloader as the
            # "valid" pass. Presumably self.eval picks the split from `mode`;
            # verify.
            self.eval(
                {train_rule_world: train_data},
                epoch=epoch,
                mode="test",
                data_mode=world_mode,
                task_idx=task_idx,
            )

            if self.config.logger.watch_model:
                norm_metric = {
                    name: weight.norm().item()
                    for name, weight in zip(
                        self.model.weight_names, self.model.weights
                    )
                }
                norm_metric["mode"] = "train"
                norm_metric["minibatch"] = self.train_step
                self.logbook.write_metric_logs(norm_metric)

            # NOTE(review): periodic_save runs twice per epoch (after training
            # and again here, after evaluation); confirm both are intended.
            self.periodic_save(epoch)