def main()

in ml3/sine_regression_task.py [0:0]


def main(exp_cfg):
    """Meta-train a learned (ML3) loss on sine-regression tasks and save the result.

    Seeds NumPy and PyTorch, builds the meta-loss network together with an
    Adam optimizer over its parameters, constructs train/test sine task
    samplers, runs ``meta_train`` and writes whatever it returns to disk
    with ``torch.save``.

    Args:
        exp_cfg: experiment configuration dict. Keys read directly here:
            ``'seed'``, ``'num_train_tasks'``, ``'num_test_tasks'``,
            ``'outer_lr'``, ``'metaloss'`` (a dict with ``'in_dim'`` and
            ``'hidden_dim'``), ``'log_dir'`` and ``'exp_log_file_name'``.
            The whole dict is also forwarded to ``meta_train``.
    """
    seed = exp_cfg['seed']
    num_train_tasks = exp_cfg['num_train_tasks']
    num_test_tasks = exp_cfg['num_test_tasks']
    outer_lr = exp_cfg['outer_lr']

    # Seed both RNGs so task sampling and network initialisation are reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)

    metaloss_cfg = exp_cfg['metaloss']
    meta_loss_model = ML3_SineRegressionLoss(in_dim=metaloss_cfg['in_dim'],
                                             hidden_dim=metaloss_cfg['hidden_dim'])

    # The outer (meta) optimizer updates the learned loss network's parameters.
    meta_optimizer = torch.optim.Adam(meta_loss_model.parameters(), lr=outer_lr)

    meta_objective = nn.MSELoss()

    # Training tasks use a fixed amplitude of 1.0 and inputs in [-2, 2]. The
    # phase range is left at SineTaskSampler's default (not visible here).
    # All tasks go into a single batch.
    task_sampler_train = SineTaskSampler(
        num_tasks_total=num_train_tasks,
        num_tasks_per_batch=num_train_tasks,
        num_data_points=100,
        amp_range=[1.0, 1.0],
        input_range=[-2.0, 2.0],
    )

    # Test tasks span a wider input range and varied amplitudes than the
    # training tasks, and have an explicit phase range.
    task_sampler_test = SineTaskSampler(
        num_tasks_total=num_test_tasks,
        num_tasks_per_batch=num_test_tasks,
        num_data_points=100,
        input_range=[-5.0, 5.0],
        amp_range=[0.2, 5.0],
        phase_range=[-np.pi, np.pi],
    )

    res = meta_train(meta_loss_model=meta_loss_model, meta_optimizer=meta_optimizer,
                     meta_objective=meta_objective,
                     task_sampler_train=task_sampler_train,
                     task_sampler_test=task_sampler_test,
                     exp_cfg=exp_cfg)

    data_file = os.path.join(exp_cfg['log_dir'], exp_cfg['exp_log_file_name'])

    # dirname() returns '' when data_file has no directory component, and
    # then there is nothing to create. exist_ok=True tolerates the directory
    # already existing, including being created concurrently between a
    # check and the makedirs call.
    data_dir = os.path.dirname(data_file)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    torch.save(res, data_file)