# def meta_train()
#
# in ml3/sine_regression_task.py [0:0]

def meta_train(meta_loss_model, meta_optimizer, meta_objective, task_sampler_train, task_sampler_test, exp_cfg):
    """Meta-train a learned loss function (ML3) on sine-regression tasks.

    Each outer iteration: reset every per-task model, take one differentiable
    inner SGD step driven by ``meta_loss_model``, score the updated model with
    ``meta_objective``, and backpropagate that task loss into the meta-loss
    parameters. Every 10 outer iterations, evaluate training with the learned
    loss vs. plain MSE on both train and test task distributions.

    Args:
        meta_loss_model: learned loss network, called as ``meta_loss_model(pred, target)``.
        meta_optimizer: optimizer over ``meta_loss_model``'s parameters.
        meta_objective: true task objective used to score the post-update model.
        task_sampler_train: sampler yielding ``(x_spt, y_spt, x_qry, y_qry)`` batches.
        task_sampler_test: sampler over held-out tasks, used only for evaluation.
        exp_cfg: config dict; reads ``'num_train_tasks'``, ``'n_outer_iter'``,
            ``'inner_lr'`` and ``'model'`` (``'in_dim'``, ``'hidden_dim'``).

    Returns:
        List of result dicts, one per evaluation (every 10 outer iterations),
        with keys ``'train_reg'``, ``'train_ml3'``, ``'test_reg'``,
        ``'test_ml3'`` and ``'task_loss'``.
    """
    num_tasks = exp_cfg['num_train_tasks']
    n_outer_iter = exp_cfg['n_outer_iter']
    inner_lr = exp_cfg['inner_lr']

    results = []

    # One model/optimizer pair per task; the models are re-initialized via
    # reset() at the start of every outer iteration.
    task_models = []
    task_opts = []
    for i in range(num_tasks):
        task_models.append(SineModel(in_dim=exp_cfg['model']['in_dim'],
                                     hidden_dim=exp_cfg['model']['hidden_dim'],
                                     out_dim=1))
        task_opts.append(torch.optim.SGD(task_models[i].parameters(), lr=inner_lr))

    for outer_i in range(n_outer_iter):
        # Sample a batch of support and query data for all tasks.
        # NOTE(review): x_qry/y_qry are sampled but never used below — the
        # post-update task loss is computed on the support set; confirm
        # whether it was meant to use the query set instead.
        x_spt, y_spt, x_qry, y_qry = task_sampler_train.sample()

        for i in range(num_tasks):
            task_models[i].reset()

        qry_losses = []
        # Zero gradients w.r.t. the meta-loss parameters before accumulating
        # over tasks. (Removed a dead `for _ in range(1)` wrapper and an
        # unused `pred_losses` list; behavior is unchanged.)
        meta_optimizer.zero_grad()

        for i in range(num_tasks):
            # `higher` makes the inner optimization step differentiable
            # w.r.t. the meta-loss parameters.
            with higher.innerloop_ctx(task_models[i], task_opts[i],
                                      copy_initial_weights=False) as (fmodel, diffopt):

                # One inner update of the task model via the learned loss.
                yp = fmodel(x_spt[i])
                pred_loss = meta_loss_model(yp, y_spt[i])
                diffopt.step(pred_loss)

                # Score the updated model with the true task objective.
                yp = fmodel(x_spt[i])
                task_loss = meta_objective(yp, y_spt[i])

                # Accumulates gradients w.r.t. the meta-loss parameters.
                task_loss.backward()
                qry_losses.append(task_loss.item())

        meta_optimizer.step()

        # qry_losses holds exactly one entry per task at this point.
        avg_qry_loss = sum(qry_losses) / num_tasks
        if outer_i % 10 == 0:
            # Periodic evaluation: train fresh task models with either plain
            # MSE ('reg') or the learned loss ('ml3'), on both the training
            # and held-out task distributions. `eval` is a sibling function
            # in this module (it shadows the builtin).
            res_train_eval_reg = eval(task_sampler=task_sampler_train, exp_cfg=exp_cfg,
                                      train_loss_fn=nn.MSELoss(), eval_loss_fn=nn.MSELoss())

            res_train_eval_ml3 = eval(task_sampler=task_sampler_train, exp_cfg=exp_cfg,
                                      train_loss_fn=meta_loss_model, eval_loss_fn=nn.MSELoss())

            res_test_eval_reg = eval(task_sampler=task_sampler_test, exp_cfg=exp_cfg,
                                     train_loss_fn=nn.MSELoss(), eval_loss_fn=nn.MSELoss())

            res_test_eval_ml3 = eval(task_sampler=task_sampler_test, exp_cfg=exp_cfg,
                                     train_loss_fn=meta_loss_model, eval_loss_fn=nn.MSELoss())

            res = {
                'train_reg': res_train_eval_reg,
                'train_ml3': res_train_eval_ml3,
                'test_reg': res_test_eval_reg,
                'test_ml3': res_test_eval_ml3,
                'task_loss': {'mse': qry_losses},
            }
            results.append(res)
            test_loss_ml3 = np.mean(res_test_eval_ml3['mse'])
            test_loss_reg = np.mean(res_test_eval_reg['mse'])
            # Fixed log message: epoch is an int (was `:.2f`), removed a
            # stray `]` after the train loss, normalized 'TestLoss' spacing.
            print(
                f'[Epoch {outer_i}] Train Loss: {avg_qry_loss:.2f} | Test Loss ML3: {test_loss_ml3:.2f} | Test Loss REG: {test_loss_reg:.2f}'
            )

    return results