# ml3/sine_regression_task.py
def meta_train(meta_loss_model, meta_optimizer, meta_objective, task_sampler_train, task_sampler_test, exp_cfg):
    """Meta-train a learned loss function (ML3) on sine-regression tasks.

    For each outer iteration: every per-task model is reset, updated for one
    inner SGD step using ``meta_loss_model`` (via ``higher`` so the step is
    differentiable), and the post-update ``meta_objective`` is backpropagated
    into the meta-loss parameters, which ``meta_optimizer`` then updates.
    Every 10 outer iterations the learned loss is compared against plain MSE
    training on both the train and test task samplers.

    Args:
        meta_loss_model: learned loss network; called as ``meta_loss_model(pred, target)``.
        meta_optimizer: optimizer over ``meta_loss_model``'s parameters.
        meta_objective: task loss (e.g. MSE) used as the outer objective.
        task_sampler_train: sampler whose ``sample()`` returns
            ``(x_spt, y_spt, x_qry, y_qry)``, each indexable by task.
        task_sampler_test: sampler of held-out tasks for periodic evaluation.
        exp_cfg: dict with keys ``num_train_tasks``, ``n_outer_iter``,
            ``inner_lr`` and ``model`` (``in_dim``, ``hidden_dim``).

    Returns:
        list of per-evaluation result dicts with keys ``train_reg``,
        ``train_ml3``, ``test_reg``, ``test_ml3`` and ``task_loss``.
    """
    num_tasks = exp_cfg['num_train_tasks']
    n_outer_iter = exp_cfg['n_outer_iter']
    inner_lr = exp_cfg['inner_lr']

    results = []

    # One task-specific model and inner-loop SGD optimizer per training task.
    task_models = []
    task_opts = []
    for i in range(num_tasks):
        task_models.append(SineModel(in_dim=exp_cfg['model']['in_dim'],
                                     hidden_dim=exp_cfg['model']['hidden_dim'],
                                     out_dim=1))
        task_opts.append(torch.optim.SGD(task_models[i].parameters(), lr=inner_lr))

    for outer_i in range(n_outer_iter):
        # Sample a batch of support and query data for every task.
        # NOTE(review): x_qry/y_qry are sampled but never used below — the
        # "query" losses are computed on the support set. Confirm intended.
        x_spt, y_spt, x_qry, y_qry = task_sampler_train.sample()

        # Reset all task models so each outer step starts from fresh weights.
        for i in range(num_tasks):
            task_models[i].reset()

        qry_losses = []
        # Zero the gradients of the meta-loss parameters once, then let each
        # task's backward() accumulate into them before a single meta step.
        meta_optimizer.zero_grad()
        for i in range(num_tasks):
            with higher.innerloop_ctx(task_models[i], task_opts[i],
                                      copy_initial_weights=False) as (fmodel, diffopt):
                # Inner step: update the task model through the *learned* loss.
                yp = fmodel(x_spt[i])
                pred_loss = meta_loss_model(yp, y_spt[i])
                diffopt.step(pred_loss)

                # Outer objective: task loss of the updated model. backward()
                # accumulates gradients w.r.t. the meta-loss parameters
                # through the differentiable inner step.
                yp = fmodel(x_spt[i])
                task_loss = meta_objective(yp, y_spt[i])
                task_loss.backward()
                qry_losses.append(task_loss.item())
        meta_optimizer.step()

        avg_qry_loss = sum(qry_losses) / num_tasks

        if outer_i % 10 == 0:
            # Periodic evaluation: train with plain MSE vs. the learned loss,
            # on both training and held-out tasks, always scoring with MSE.
            # NOTE(review): `eval` here is a project-level function that
            # shadows the builtin — consider renaming at its definition site.
            res_train_eval_reg = eval(task_sampler=task_sampler_train, exp_cfg=exp_cfg,
                                      train_loss_fn=nn.MSELoss(), eval_loss_fn=nn.MSELoss())
            res_train_eval_ml3 = eval(task_sampler=task_sampler_train, exp_cfg=exp_cfg,
                                      train_loss_fn=meta_loss_model, eval_loss_fn=nn.MSELoss())
            res_test_eval_reg = eval(task_sampler=task_sampler_test, exp_cfg=exp_cfg,
                                     train_loss_fn=nn.MSELoss(), eval_loss_fn=nn.MSELoss())
            res_test_eval_ml3 = eval(task_sampler=task_sampler_test, exp_cfg=exp_cfg,
                                     train_loss_fn=meta_loss_model, eval_loss_fn=nn.MSELoss())

            res = {}
            res['train_reg'] = res_train_eval_reg
            res['train_ml3'] = res_train_eval_ml3
            res['test_reg'] = res_test_eval_reg
            res['test_ml3'] = res_test_eval_ml3
            res['task_loss'] = {}
            res['task_loss']['mse'] = qry_losses
            results.append(res)

            test_loss_ml3 = np.mean(res_test_eval_ml3['mse'])
            test_loss_reg = np.mean(res_test_eval_reg['mse'])
            print(
                f'[Epoch {outer_i:.2f}] Train Loss: {avg_qry_loss:.2f}]| Test Loss ML3: {test_loss_ml3:.2f} | TestLoss REG: {test_loss_reg:.2f}'
            )

    return results