def __do_model_step()

in compare_models.py [0:0]


    def __do_model_step(self, mi, scmap, race, inp, targ, game_name, optimize=True):
        '''
        Run a single BPTT step of model `mi`: forward pass, loss computation and,
        when `optimize` is True, a backward pass plus an optimizer step. Models
        with a latent z first refine `zbwd` by gradient descent (with the rest of
        the model frozen) and also train the z predictor against that refined z.
        Returns the detached outputs and the detached loss.
        '''
        model = self.models[mi]
        has_z = hasattr(model, "with_z") and model.with_z
        model.hidden = model.repackage_hidden(model.hidden)
        if optimize:
            for _, optimizer in self.optimizers[mi].items():
                optimizer.zero_grad()

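        # For models with a latent z, first refine the backward latent zbwd by
        # gradient descent on the main loss while every model parameter is
        # frozen, then unfreeze and update the whole model below.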
        if has_z and optimize:
            old_hiddens = model.hidden
            should_retain = hasattr(model, "z_pred_cut_gradient") and not model.z_pred_cut_gradient

            if model.zbwd_init_zfwd:
                if not model.zbwd_single:
                    model.zbwd = Variable(th.zeros(inp.size(0),1,model.zsize).type(th.cuda.FloatTensor))
                model.zbwd_initialized = False

            # disable gradient updates to the rest of the model
            for param in model.parameters():
                param.requires_grad = False
            model.zbwd.requires_grad = True

            # TODO remove this multiple optimizer stuff since we just create a new one here
            self.optimizers[mi]['zbwd'] = model.z_opt([model.zbwd], lr=model.z_lr)
            if not hasattr(model, "tail_values"):
                model.tail_values = {'zbwd_norm': deque(maxlen=200), 'zbwd_grad_norm': deque(maxlen=200), 'zpred_norm': deque(maxlen=200), 'zpred_loss': deque(maxlen=200), 'zpred_grad_norm': deque(maxlen=200)}
            z_step = 0
            # Do a bunch of model forwards to calculate the gradient wrt z_bwd
            if model.zfwd_zbwd_ratio > 0 and random.random() > model.zfwd_zbwd_ratio:
                model.hidden = model.repackage_hidden(old_hiddens)
                input, embed = model.trunk_encode_pool(scmap, race, inp)
                while z_step == 0 or (model.zbwd_to_convergence
                                      and z_step < 5 * np.log10(self.update_number + 1)
                                      and model.z_lr * model.zbwd.grad.norm().data[0] > 1e-4):
                    self.optimizers[mi]['zbwd'].zero_grad()
                    model.hidden = model.repackage_hidden(old_hiddens)
                    output = model.forward_rest(input, embed)
                    loss = self.loss_fns[mi](inp, output, targ)
                    if self.args.loss_averaging:
                        loss = loss / targ.size(0)
                    loss.backward()
                    if self.update_number % 2000 == 0:
                        logging.log(42, "    Lf_zbwd_grad {} zbwd_norm {}".format(model.zbwd.grad.norm().data[0], model.zbwd.norm().data[0]))
                    self.optimizers[mi]['zbwd'].step()
                    z_step += 1
                model.tail_values['zbwd_norm'].append(model.zbwd.norm().data[0])
                model.tail_values['zbwd_grad_norm'].append(model.zbwd.grad.norm().data[0])
            if self.update_number % 2000 == 0:
                logging.log(42, "AVG_Lf_zbwd_grad {} zbwd_norm {}".format(avg(model.tail_values['zbwd_grad_norm']), avg(model.tail_values['zbwd_norm'])))
                logging.log(42, "Lf_zbwd_grad {}, zbwd_norm {}".format(model.zbwd.grad.norm().data[0], model.zbwd.norm().data[0]))

            # Finally, update the whole model
            for param in model.parameters():
                param.requires_grad = True

            model.hidden = model.repackage_hidden(old_hiddens)
            output = model(scmap, race, inp)
            loss = self.loss_fns[mi](inp, output, targ)
            if self.args.loss_averaging:
                loss = loss / targ.size(0)


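            # Train the z predictor: drive the model's forward prediction zfwd
            # toward the refined zbwd, which is detached so it acts as a target.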
            if model.zbwd_init_zfwd:
                zvar = F.tanh(model.zbwd).detach()
                if model.zbwd_single:
                    zvar = zvar.expand(inp.size(0), 1, model.zsize)
            else:
                zvar = Variable(model.zbwd.data).expand((inp.size(0), 1, model.zsize))
            zloss = model.z_lambda * model.zlossfn(model.zfwd, zvar)
            zloss.backward(retain_graph=True)
            zpred_grad_norm = next(model.zpred.parameters()).grad.norm().data[0]
            zpred_norm = next(model.zpred.parameters()).norm().data[0]
            if self.update_number % 2000 < 50:
                logging.log(42, "loss {}".format(loss.data[0]))
                logging.log(42, "Lz_zpred_loss {} zpred_grad_norm {} zpred_params_norm {}".format(zloss.data[0], zpred_grad_norm , zpred_norm))
            model.tail_values['zpred_norm'].append(zpred_norm)
            model.tail_values['zpred_grad_norm'].append(zpred_grad_norm)
            model.tail_values['zpred_loss'].append(zloss.data[0])
            if self.update_number % 2000 == 0:
                logging.log(42, "AVG_Lz_zpred_loss {} zpred_grad_norm {}, zpred_params_norm {}".format(avg(model.tail_values['zpred_loss']), avg(model.tail_values['zpred_grad_norm']), avg(model.tail_values['zpred_norm'])))
        else:
            output = model(scmap, race, inp)
            loss = self.loss_fns[mi](inp, output, targ)
            if self.args.loss_averaging:
                loss = loss / targ.size(0)

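        # Backpropagate the main loss, optionally clip gradients, and step the
        # model optimizer.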
        if optimize:
            loss.backward()
            if self.args.clip > 0:
                th.nn.utils.clip_grad_norm(model.parameters(), self.args.clip)
            self.optimizers[mi]['model'].step()

        return [x.detach() for x in output], loss.detach()
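
For reference, below is a minimal, self-contained sketch of the inner zbwd refinement pattern used above, written against current PyTorch rather than the legacy Variable API in the listing. All names here (refine_latent, loss_fn, max_steps, tol) are illustrative and not part of compare_models.py, and the model is assumed to take the latent as an explicit argument instead of reading it from an attribute; the stopping rule is approximated with a fixed step budget plus a gradient-norm threshold.

    import torch

    def refine_latent(model, z, inp, targ, loss_fn, lr=0.1, max_steps=10, tol=1e-4):
        """Gradient-descend on the latent z while the model's weights stay frozen."""
        # Freeze the model so only z receives gradient updates.
        for p in model.parameters():
            p.requires_grad_(False)
        z = z.clone().detach().requires_grad_(True)
        opt = torch.optim.SGD([z], lr=lr)
        for _ in range(max_steps):
            opt.zero_grad()
            loss = loss_fn(model(inp, z), targ)
            loss.backward()
            # Stop once the effective update to z becomes negligible, mirroring
            # the `model.z_lr * zbwd.grad.norm() > 1e-4` check in the loop above.
            if lr * z.grad.norm().item() < tol:
                break
            opt.step()
        # Unfreeze before the regular optimizer updates the whole model.
        for p in model.parameters():
            p.requires_grad_(True)
        return z.detach()

Freezing the weights during this inner loop keeps the repeated backward passes from accumulating gradients into the model parameters, so only the latent moves; the outer step at the end of __do_model_step then updates the weights with a single backward on the main loss.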