in experiments/codes/experiment/checkpointable_multitask_experiment.py [0:0]
def run(self):
    """Run the experiment selected by ``self.config.general.train_mode``.

    Steps, in order:

    1. Validate ``train_mode`` so an unsupported value fails fast, before any
       loading, freezing or optimizer work is done.
    2. Depending on the ``self.config.model`` flags, call
       ``load_only_composition`` / ``load_only_representation`` and freeze
       the composition / representation modules.
    3. Re-register the optimizer/scheduler via ``register_optim_sched``,
       passing the freeze flags as the ``skip_*_registry`` arguments.
    4. Invoke the training routine mapped to ``train_mode``.

    Raises:
        NotImplementedError: if ``train_mode`` is not one of the keys of
            ``mode_to_routine`` below. The message lists every supported mode
            and the rejected value.
    """
    # train_mode -> (name of the method to call, keyword arguments for it).
    # Methods are resolved with getattr only for the selected mode, so no
    # routine is looked up unless it is actually going to run.
    mode_to_routine = {
        "run_mult": ("run_multitask_training", {}),
        "run_mult_unique_comp": ("run_multitask_training_unique_composition", {}),
        "run_mult_unique_rep": ("run_multitask_training_unique_representation", {}),
        "supervised": ("run_single_task", {"world_mode": "train"}),
        "supervised_valid": ("run_single_task", {"world_mode": "valid"}),
        "supervised_test": ("run_single_task", {"world_mode": "test"}),
        "seq_mult": ("run_sequential_multitask_training", {}),
        "seq_mult_comp": ("run_sequential_multitask_unique_composition", {}),
        "seq_mult_rep": ("run_sequential_multitask_unique_representation", {}),
        "seq_zero": ("run_sequential_zeroshot_transfer", {}),
        "seq_full": ("run_sequential_fewshot_transfer", {"full_shot": True}),
        "seq_few": ("run_sequential_fewshot_transfer", {}),
        "pretrain": ("run_pretraining", {}),
    }

    train_mode = self.config.general.train_mode
    if train_mode not in mode_to_routine:
        # List every accepted mode plus the rejected value so a config typo
        # is easy to spot.
        raise NotImplementedError(
            "training mode {!r} not implemented. should be either one of \n {}".format(
                train_mode, " / ".join(mode_to_routine)
            )
        )

    if self.config.model.use_composition_fn:
        self.load_only_composition()
    if self.config.model.use_representation_fn:
        self.load_only_representation()
    if self.config.model.freeze_composition_fn:
        # NOTE(review): when use_composition_fn is also set, this calls
        # load_only_composition() a second time; kept to preserve behavior.
        self.load_only_composition()
        self.composition_fn.freeze_weights()
    if self.config.model.freeze_representation_fn:
        # NOTE(review): unlike the composition branch, the representation is
        # not (re)loaded before freezing -- confirm this asymmetry is intended.
        self.representation_fn.freeze_weights()

    # Re-register the params to the optimizer; modules whose freeze flag is
    # set are passed as skip_*_registry=True.
    self.register_optim_sched(
        skip_composition_registry=self.config.model.freeze_composition_fn,
        skip_representation_registry=self.config.model.freeze_representation_fn,
    )

    method_name, kwargs = mode_to_routine[train_mode]
    getattr(self, method_name)(**kwargs)