experiments/codes/experiment/checkpointable_multitask_experiment.py
def run_multitask_training_unique_representation(self):
"""Run multitask training with unique representation functions
One composition function, and individual representation functions for each task
"""
if self.config.model.should_load_model:
self.load_model()
# set up one representation function per training world; all copies are
# kept in CPU memory and moved to the GPU only when their world is sampled
num_worlds = len(self.dataloaders["train"])
representation_world_cache = {
ni: copy.deepcopy(self.representation_fn) for ni in range(num_worlds)
}
for rm in representation_world_cache.values():
rm.to("cpu")
torch.cuda.empty_cache()
# the composition optimizer is shared across all tasks; per-task
# representation optimizers are rebuilt each epoch and their states cached
representation_world_optimizer_states = {}
composition_optimizer = Net.register_params_to_optimizer(
self.config,
self.composition_fn.model.weights,
is_signature=self.is_signature,
)
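# start from epoch 0 unless a checkpoint has already set self.epoch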
if self.epoch is None:
self.epoch = 0
if self.config.model.should_train:
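# each epoch samples one training world and updates the shared composition
# function together with that world's representation function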
for epoch in range(self.epoch, self.config.model.num_epochs):
train_world_names = list(self.dataloaders["train"].keys())
train_rule_world = random.choice(train_world_names)
task_idx = train_world_names.index(train_rule_world)
self.logbook.write_message_logs(f"Training rule {train_rule_world}")
# ipdb.set_trace()
self.logbook.write_message_logs(
f"Choosing to train the model " f"on {train_rule_world}"
)
# Train, optimize and test on the same world
train_data = self.dataloaders["train"][train_rule_world]
# set the correct representation function
representation_world_cache[task_idx].to(self.config.general.device)
self.representation_fn = representation_world_cache[task_idx]
representation_optimizer = Net.register_params_to_optimizer(
self.config,
self.representation_fn.model.weights,
is_signature=self.is_signature,
)
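# restore this task's optimizer state from the last time it was trained, if any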
if task_idx in representation_world_optimizer_states:
representation_optimizer.load_state_dict(
representation_world_optimizer_states[task_idx]
)
self.optimizers = [representation_optimizer, composition_optimizer]
self.train(train_data, train_rule_world, epoch, task_idx=task_idx)
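# validate on the world that was just trained, using its training split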
metrics = self.eval(
{train_rule_world: self.dataloaders["train"][train_rule_world]},
epoch=epoch,
mode="valid",
data_mode="train",
task_idx=task_idx,
)
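# step the learning-rate schedulers on the reported validation loss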
for sched in self.schedulers:
sched.step(metrics["loss"])
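# move the representation function back to CPU and cache its optimizer state
# so both can be restored the next time this task is sampled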
representation_world_cache[task_idx].to("cpu")
representation_world_optimizer_states[task_idx] = self.optimizers[
0
].state_dict()
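# periodically checkpoint the experiment state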
self.periodic_save(epoch)