def run_multitask_training_unique_composition()

in experiments/codes/experiment/checkpointable_multitask_experiment.py [0:0]


    def run_multitask_training_unique_composition(self):
        """Run multitask training with one composition function per task.

        A single representation function (``self.representation_fn``) is shared
        by all tasks, while every training world gets its own deep copy of
        ``self.composition_fn``. The per-world copies are kept on the CPU and
        are only moved to ``self.config.general.device`` while their world is
        being trained.

        Each world's composition-optimizer state is written to
        ``<save_dir>/opts/<task_idx>_opt.pt`` after its epoch. It is reloaded
        the next time that world is picked, because a fresh optimizer object
        is registered each epoch.

        Each epoch does the following:

        * picks one world uniformly at random from ``self.dataloaders["train"]``;
        * trains on it;
        * evaluates on that same world's training data
          (``mode="valid"``, ``data_mode="train"``);
        * steps every scheduler in ``self.schedulers`` with the returned
          ``metrics["loss"]``.
        """
        if self.config.model.should_load_model:
            self.load_model()

        train_dataloaders = self.dataloaders["train"]
        # The world list is fixed for the whole run; a world's position in
        # this list is its task index (used for the cache and the opt files).
        train_world_names = list(train_dataloaders.keys())
        num_worlds = len(train_world_names)

        # One independent composition function per training world, parked on
        # the CPU until its world is selected.
        composition_world_cache = {
            ni: copy.deepcopy(self.composition_fn) for ni in range(num_worlds)
        }
        for composition in composition_world_cache.values():
            composition.to("cpu")
        torch.cuda.empty_cache()

        optim_store_location = os.path.join(self.config.model.save_dir, "opts")
        # exist_ok: when resuming a run in the same save_dir the directory
        # already exists; without it os.makedirs raises FileExistsError.
        os.makedirs(optim_store_location, exist_ok=True)

        # The representation optimizer is shared across all tasks for the
        # whole run; composition optimizers are rebuilt per epoch below.
        representation_optimizer = Net.register_params_to_optimizer(
            self.config,
            self.representation_fn.model.weights,
            is_signature=self.is_signature,
        )
        if self.epoch is None:
            self.epoch = 0
        if not self.config.model.should_train:
            return

        device = self.config.general.device
        for epoch in range(self.epoch, self.config.model.num_epochs):
            train_rule_world = random.choice(train_world_names)
            task_idx = train_world_names.index(train_rule_world)
            self.logbook.write_message_logs(f"Training rule {train_rule_world}")
            self.logbook.write_message_logs(
                f"Choosing to train the model on {train_rule_world}"
            )

            # Train, optimize and test on the same world.
            train_data = train_dataloaders[train_rule_world]

            # Activate this world's composition function on the device.
            composition_world_cache[task_idx].to(device)
            self.composition_fn = composition_world_cache[task_idx]
            composition_optimizer = Net.register_params_to_optimizer(
                self.config,
                self.composition_fn.model.weights,
                is_signature=self.is_signature,
            )
            optim_store_file = os.path.join(
                optim_store_location, "{}_opt.pt".format(task_idx)
            )
            if os.path.exists(optim_store_file):
                # map_location: the state may have been saved on a different
                # device than the one this run uses.
                composition_optimizer.load_state_dict(
                    torch.load(optim_store_file, map_location=device)
                )
            self.optimizers = [representation_optimizer, composition_optimizer]

            self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
            metrics = self.eval(
                {train_rule_world: train_data},
                epoch=epoch,
                mode="valid",
                data_mode="train",
                task_idx=task_idx,
            )
            # NOTE(review): self.schedulers is not rebuilt here, so it is not
            # bound to the composition optimizer registered above this epoch.
            # Confirm that this is intended.
            for sched in self.schedulers:
                sched.step(metrics["loss"])

            # Park the composition function back on the CPU and persist its
            # optimizer state for the next time this world is picked.
            composition_world_cache[task_idx].to("cpu")
            torch.save(self.optimizers[-1].state_dict(), optim_store_file)
            torch.cuda.empty_cache()
            # NOTE(review): periodic_save only receives the epoch. Whether it
            # persists the per-world composition copies (rather than only
            # self.composition_fn) cannot be seen here. If it does not, a
            # resumed run re-seeds every world from self.composition_fn. This
            # needs to be checked against periodic_save.
            self.periodic_save(epoch)
        # NOTE(review): after the loop self.composition_fn is the last-trained
        # world's copy and has been moved to the CPU.