in ml3/ml3_train.py [0:0]
def _shaped_sine_inner_step(sine_model, inner_opt, ml3_loss, task_loss_fn, shaped,
                            inputs, labels, label_thetas):
    """Adapt ``sine_model`` by one differentiable step on the learned loss.

    The step is taken inside ``higher.innerloop_ctx``, so the returned task loss is
    computed from the adapted frequency and can be back-propagated through the
    update into ``ml3_loss``'s parameters.

    Returns:
        A ``(task_loss, adapted_freq)`` pair: whatever ``task_loss_fn`` returns for
        the adapted model, and a detached copy of the adapted ``freq`` parameter.
    """
    with higher.innerloop_ctx(sine_model, inner_opt) as (fmodel, diffopt):
        # Inner update: minimise the current learned (meta) loss, whose input is
        # the inputs, predictions and labels concatenated along dim 1.
        yp = fmodel(inputs)
        meta_input = torch.cat([inputs, yp, labels], dim=1)
        diffopt.step(ml3_loss(meta_input).mean())

        # Score the adapted model with the task loss.
        yp = fmodel(inputs)
        task_loss = task_loss_fn(inputs, yp, labels, shaped, fmodel.freq, label_thetas)
        adapted_freq = fmodel.freq.clone().detach()
    return task_loss, adapted_freq


def meta_train_shaped_sine(n_outer_iter, shaped, num_task, n_inner_iter, sine_model, ml3_loss,
                           task_loss_fn, exp_folder, batch_size=64):
    """Meta-train the learned loss ``ml3_loss`` on shaped sine-frequency tasks.

    Each outer iteration samples a batch of sinusoid tasks. For every task a fresh
    ``ShapedSineModel`` is adapted for ``n_inner_iter`` steps using the learned
    loss. After each step, the task loss of the adapted model is back-propagated
    and an Adam optimizer updates ``ml3_loss``.

    Args:
        n_outer_iter: number of outer (meta) iterations.
        shaped: forwarded to ``task_loss_fn`` and ``plot_loss``; also embedded in
            the output file names.
        num_task: number of tasks sampled per outer iteration.
        n_inner_iter: number of inner adaptation steps per task.
        sine_model: ignored -- a new ``ShapedSineModel`` is created for every task.
            Kept only so existing callers keep working.
        ml3_loss: the learned loss module; must expose a ``learning_rate`` attribute.
        task_loss_fn: called as
            ``task_loss_fn(inputs, yp, labels, shaped, freq, label_thetas)``.
        exp_folder: directory that checkpoints and landscape arrays are written to.
        batch_size: second argument passed to ``generate_sinusoid_batch``
            (previously hard-coded to 64; presumably samples per inner step --
            confirm against ``generate_sinusoid_batch``).

    Side effects:
        - prints the most recent task loss every 100 outer iterations;
        - saves ``ml3_loss.state_dict()`` after every outer iteration;
        - every 10 outer iterations calls ``plot_loss`` and re-saves the
          accumulated landscape lists with ``np.save``.
    """
    theta_ranges = []
    landscape_with_extra = []
    landscape_mse = []
    meta_opt = torch.optim.Adam(ml3_loss.parameters(), lr=ml3_loss.learning_rate)

    for outer_i in range(n_outer_iter):
        batch_inputs, batch_labels, batch_thetas = generate_sinusoid_batch(
            num_task, batch_size, n_inner_iter)

        # Stays None when num_task or n_inner_iter is 0, so the progress print
        # below cannot hit an unbound name.
        task_loss = None
        for task in range(num_task):
            sine_model = ShapedSineModel()
            inner_opt = torch.optim.SGD([sine_model.freq], lr=sine_model.learning_rate)
            for step in range(n_inner_iter):
                inputs = torch.Tensor(batch_inputs[task, step, :])
                labels = torch.Tensor(batch_labels[task, step, :])
                label_thetas = torch.Tensor(batch_thetas[task, step, :])

                # Update the frequency parameter with the learned loss, keeping the
                # step differentiable w.r.t. the meta-loss parameters.
                task_loss, adapted_freq = _shaped_sine_inner_step(
                    sine_model, inner_opt, ml3_loss, task_loss_fn, shaped,
                    inputs, labels, label_thetas)

                # Carry the adapted frequency forward as a new leaf Parameter. The SGD
                # optimizer must be rebuilt because it still references the old one.
                sine_model.freq = torch.nn.Parameter(adapted_freq)
                inner_opt = torch.optim.SGD([sine_model.freq], lr=sine_model.learning_rate)

                # Update the learned loss: reset the meta-loss gradients to 0, then
                # back-propagate the task loss through the inner step.
                meta_opt.zero_grad()
                task_loss.mean().backward()
                meta_opt.step()

        if outer_i % 100 == 0 and task_loss is not None:
            print(f"task loss: {task_loss.mean().item()}")

        torch.save(ml3_loss.state_dict(), f'{exp_folder}/ml3_loss_shaped_sine_{shaped}.pt')

        if outer_i % 10 == 0:
            t_range, l_with_extra, l_mse = plot_loss(shaped, exp_folder)
            theta_ranges.append(t_range)
            landscape_with_extra.append(l_with_extra)
            landscape_mse.append(l_mse)
            # Re-save the full history so partial runs still leave usable files.
            np.save(f'{exp_folder}/theta_ranges_{shaped}_.npy', theta_ranges)
            np.save(f'{exp_folder}/landscape_with_extra_{shaped}_.npy', landscape_with_extra)
            np.save(f'{exp_folder}/landscape_mse_{shaped}_.npy', landscape_mse)