# compare_models.py
def __do_model_step(self, mi, scmap, race, inp, targ, game_name, optimize=True):
    '''
    Run a single BPTT step of model `mi`; optionally backprop and apply an optimizer update.
    '''
    model = self.models[mi]
    has_z = hasattr(model, "with_z") and model.with_z
    # Detach the recurrent state so gradients do not flow into previous steps.
    model.hidden = model.repackage_hidden(model.hidden)
    if optimize:
        for _, optimizer in self.optimizers[mi].items():
            optimizer.zero_grad()
    if has_z and optimize:
        # Latent-variable path: first optimize z_bwd for this batch with the rest of
        # the model frozen, then do the usual full-model update below.
        old_hiddens = model.hidden
        should_retain = hasattr(model, "z_pred_cut_gradient") and not model.z_pred_cut_gradient
        if model.zbwd_init_zfwd:
            if not model.zbwd_single:
                model.zbwd = Variable(th.zeros(inp.size(0), 1, model.zsize).type(th.cuda.FloatTensor))
            model.zbwd_initialized = False
        # Disable gradient updates to the rest of the model.
        for param in model.parameters():
            param.requires_grad = False
        model.zbwd.requires_grad = True
        # TODO remove this multiple optimizer stuff since we just create a new one here
        self.optimizers[mi]['zbwd'] = model.z_opt([model.zbwd], lr=model.z_lr)
        if not hasattr(model, "tail_values"):
            # Running windows over the last 200 values, used for the periodic AVG_* logs below.
            model.tail_values = {
                'zbwd_norm': deque(maxlen=200),
                'zbwd_grad_norm': deque(maxlen=200),
                'zpred_norm': deque(maxlen=200),
                'zpred_loss': deque(maxlen=200),
                'zpred_grad_norm': deque(maxlen=200),
            }
        z_step = 0
        # Do a number of model forwards to compute the gradient w.r.t. z_bwd,
        # updating only z_bwd itself.
        if model.zfwd_zbwd_ratio > 0 and random.random() > model.zfwd_zbwd_ratio:
            model.hidden = model.repackage_hidden(old_hiddens)
            input, embed = model.trunk_encode_pool(scmap, race, inp)
            # Always take at least one step; with zbwd_to_convergence, keep stepping until the
            # step size (lr * grad norm) is negligible, capped at 5*log10(update_number + 1) steps.
            while z_step == 0 or (model.zbwd_to_convergence
                                  and z_step < 5 * np.log10(self.update_number + 1)
                                  and model.z_lr * model.zbwd.grad.norm().data[0] > 1e-4):
                self.optimizers[mi]['zbwd'].zero_grad()
                model.hidden = model.repackage_hidden(old_hiddens)
                output = model.forward_rest(input, embed)
                loss = self.loss_fns[mi](inp, output, targ)
                if self.args.loss_averaging:
                    loss = loss / targ.size(0)
                loss.backward()
                if self.update_number % 2000 == 0:
                    logging.log(42, " Lf_zbwd_grad {} zbwd_norm {}".format(
                        model.zbwd.grad.norm().data[0], model.zbwd.norm().data[0]))
                self.optimizers[mi]['zbwd'].step()
                z_step += 1
            model.tail_values['zbwd_norm'].append(model.zbwd.norm().data[0])
            model.tail_values['zbwd_grad_norm'].append(model.zbwd.grad.norm().data[0])
            if self.update_number % 2000 == 0:
                logging.log(42, "AVG_Lf_zbwd_grad {} zbwd_norm {}".format(
                    avg(model.tail_values['zbwd_grad_norm']), avg(model.tail_values['zbwd_norm'])))
                logging.log(42, "Lf_zbwd_grad {}, zbwd_norm {}".format(
                    model.zbwd.grad.norm().data[0], model.zbwd.norm().data[0]))
        # Finally, update the whole model: unfreeze the parameters and run the full forward pass.
        for param in model.parameters():
            param.requires_grad = True
        model.hidden = model.repackage_hidden(old_hiddens)
        output = model(scmap, race, inp)
        loss = self.loss_fns[mi](inp, output, targ)
        if self.args.loss_averaging:
            loss = loss / targ.size(0)
        # Train the z predictor (zfwd) to regress onto the optimized z_bwd target.
        if model.zbwd_init_zfwd:
            zvar = F.tanh(model.zbwd).detach()
            if model.zbwd_single:
                zvar = zvar.expand(inp.size(0), 1, model.zsize)
        else:
            zvar = Variable(model.zbwd.data).expand((inp.size(0), 1, model.zsize))
        zloss = model.z_lambda * model.zlossfn(model.zfwd, zvar)
        zloss.backward(retain_graph=True)
        zpred_grad_norm = next(model.zpred.parameters()).grad.norm().data[0]
        zpred_norm = next(model.zpred.parameters()).norm().data[0]
        if self.update_number % 2000 < 50:
            logging.log(42, "loss {}".format(loss.data[0]))
            logging.log(42, "Lz_zpred_loss {} zpred_grad_norm {} zpred_params_norm {}".format(
                zloss.data[0], zpred_grad_norm, zpred_norm))
        model.tail_values['zpred_norm'].append(zpred_norm)
        model.tail_values['zpred_grad_norm'].append(zpred_grad_norm)
        model.tail_values['zpred_loss'].append(zloss.data[0])
        if self.update_number % 2000 == 0:
            logging.log(42, "AVG_Lz_zpred_loss {} zpred_grad_norm {}, zpred_params_norm {}".format(
                avg(model.tail_values['zpred_loss']),
                avg(model.tail_values['zpred_grad_norm']),
                avg(model.tail_values['zpred_norm'])))
    else:
        # No latent path (or evaluation only): plain forward pass.
        output = model(scmap, race, inp)
        loss = self.loss_fns[mi](inp, output, targ)
        if self.args.loss_averaging:
            loss = loss / targ.size(0)
    if optimize:
        loss.backward()
        if self.args.clip > 0:
            th.nn.utils.clip_grad_norm(model.parameters(), self.args.clip)
        self.optimizers[mi]['model'].step()
    return [x.detach() for x in output], loss.detach()
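
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this trainer): the core pattern above is
# "inference by backprop" on a latent code -- freeze the model parameters,
# optimize z alone for the current batch, then unfreeze and do the regular
# update. The toy model, sizes, and hyper-parameters below are placeholders,
# and the sketch assumes a recent PyTorch API (requires_grad= tensors,
# .item()) rather than the Variable-era API used in this file.
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, in_dim=8, zsize=16):
        super().__init__()
        self.fc = nn.Linear(in_dim + zsize, in_dim)

    def forward(self, x, z):
        # Condition the prediction on the latent code z.
        return self.fc(torch.cat([x, z], dim=1))


def infer_z_by_backprop(model, inp, targ, loss_fn, zsize=16, z_lr=0.1, max_steps=20):
    """Freeze the model, then gradient-descend a single latent z to fit one batch."""
    for p in model.parameters():
        p.requires_grad = False
    z = torch.zeros(1, zsize, requires_grad=True)
    opt = torch.optim.SGD([z], lr=z_lr)
    for _ in range(max_steps):
        opt.zero_grad()
        loss = loss_fn(model(inp, z.expand(inp.size(0), zsize)), targ)
        loss.backward()
        # Stop once the step the optimizer would take is negligible,
        # mirroring the z_lr * grad-norm threshold used above.
        if z_lr * z.grad.norm().item() < 1e-4:
            break
        opt.step()
    for p in model.parameters():
        p.requires_grad = True
    return z.detach()


if __name__ == "__main__":
    net = ToyModel()
    x, y = torch.randn(4, 8), torch.randn(4, 8)
    z_star = infer_z_by_backprop(net, x, y, nn.MSELoss())
    print(z_star.norm().item())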