def meta_train_shaped_sine()

in ml3/ml3_train.py [0:0]


def meta_train_shaped_sine(n_outer_iter, shaped, num_task, n_inner_iter, sine_model, ml3_loss, task_loss_fn, exp_folder):
    """Meta-train a learned (ML3) loss on the shaped sine-frequency task.

    Each outer iteration draws a fresh batch of sinusoid tasks. For every task a new
    ``ShapedSineModel`` is built, and its ``freq`` parameter is updated for
    ``n_inner_iter`` steps. Each step is one differentiable SGD step on the learned loss
    (through ``higher.innerloop_ctx``). After each inner step, the task loss of the
    updated model is back-propagated into ``ml3_loss`` and ``meta_opt`` takes a step.

    Args:
        n_outer_iter: number of outer (meta) iterations.
        shaped: forwarded to ``task_loss_fn`` and ``plot_loss``; also embedded in the
            output file names.
        num_task: number of tasks per outer iteration.
        n_inner_iter: number of inner update steps per task.
        sine_model: ignored -- a fresh ``ShapedSineModel`` is created for every task.
            Kept only so existing callers keep working.
        ml3_loss: learned loss module; must provide ``parameters()``,
            ``learning_rate`` and ``state_dict()``.
        task_loss_fn: called as
            ``task_loss_fn(inputs, yp, labels, shaped, freq, label_thetas)``.
        exp_folder: directory that receives the checkpoint and landscape files.

    Side effects:
        Prints the latest task loss every 100 outer iterations.
        Overwrites ``<exp_folder>/ml3_loss_shaped_sine_<shaped>.pt`` every outer iteration.
        Calls ``plot_loss`` every 10 outer iterations and, at the end, saves the collected
        results to three ``.npy`` files in ``exp_folder``.
    """
    theta_ranges = []
    landscape_with_extra = []
    landscape_mse = []

    # The checkpoint path does not depend on the loop, so build it once.
    checkpoint_path = f'{exp_folder}/ml3_loss_shaped_sine_{shaped}.pt'

    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)

    # Latest task loss, kept for progress logging. It stays None if the task/step loops
    # never run (num_task == 0 or n_inner_iter == 0). Previously that case crashed the
    # print below with UnboundLocalError.
    task_loss = None

    for outer_i in range(n_outer_iter):
        # Draw a fresh batch of sinusoid tasks for this meta iteration.
        batch_inputs, batch_labels, batch_thetas = generate_sinusoid_batch(num_task, 64, n_inner_iter)
        for task in range(num_task):
            # Every task starts from a newly initialised model; the `sine_model`
            # argument is intentionally not used.
            task_model = ShapedSineModel()
            inner_opt = torch.optim.SGD([task_model.freq], lr=task_model.learning_rate)
            for step in range(n_inner_iter):
                inputs = torch.Tensor(batch_inputs[task, step, :])
                labels = torch.Tensor(batch_labels[task, step, :])
                label_thetas = torch.Tensor(batch_thetas[task, step, :])

                # Inner update: take one differentiable optimizer step on the learned loss,
                # so the task loss below can be back-propagated into ml3_loss.
                with higher.innerloop_ctx(task_model, inner_opt) as (fmodel, diffopt):
                    yp = fmodel(inputs)
                    meta_input = torch.cat([inputs, yp, labels], dim=1)
                    loss = ml3_loss(meta_input).mean()
                    diffopt.step(loss)

                    # Evaluate the task loss with the post-update parameters.
                    yp = fmodel(inputs)
                    task_loss = task_loss_fn(inputs, yp, labels, shaped, fmodel.freq, label_thetas)

                # Carry the updated frequency into the next step as a detached leaf
                # parameter, then rebuild the inner optimizer around the new tensor.
                task_model.freq = torch.nn.Parameter(fmodel.freq.clone().detach())
                inner_opt = torch.optim.SGD([task_model.freq], lr=task_model.learning_rate)

                # Outer update: clear stale gradients on the learned-loss parameters,
                # back-propagate the task loss, and step the meta optimizer.
                meta_opt.zero_grad()
                task_loss.mean().backward()
                meta_opt.step()

        if outer_i % 100 == 0 and task_loss is not None:
            print("task loss: {}".format(task_loss.mean().item()))

        # Overwrite the same checkpoint file every outer iteration.
        torch.save(ml3_loss.state_dict(), checkpoint_path)

        if outer_i % 10 == 0:
            t_range, l_with_extra, l_mse = plot_loss(shaped, exp_folder)
            theta_ranges.append(t_range)
            landscape_with_extra.append(l_with_extra)
            landscape_mse.append(l_mse)

    np.save(f'{exp_folder}/theta_ranges_{shaped}_.npy', theta_ranges)
    np.save(f'{exp_folder}/landscape_with_extra_{shaped}_.npy', landscape_with_extra)
    np.save(f'{exp_folder}/landscape_mse_{shaped}_.npy', landscape_mse)