in compare_models.py
def _do_model_step(self, scmap, race, inp, targ, game_name, optimize=True):
    ''' Runs one step of every model on the same (input, target) pair.

    Returns:
        ([list of outputs for each model], [list of losses for each model])
    '''
self.update_number += 1
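    # Reset per-game state for every model before processing this sequence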
for model in self.models:
if hasattr(model, "init_z"):
model.init_z(game_name)
model.hidden = model.init_hidden()
losses = []
outputs = []
for mi, model in enumerate(self.models):
if self.args.bptt > 0 and model.accepts_bptt:
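            # Truncated BPTT: split the sequence into slices of length
            # args.bptt and run each slice as a separate step.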
cur_losses = []
cur_outputs = []
for start, end in self.__generate_bptt_slices(inp.size(0)):
output, loss = self.__do_model_step(mi, scmap, race, inp[start:end], targ[start:end], game_name)
                # When averaging, rescale the per-slice mean loss back to a
                # sum over timesteps; it is re-averaged over the full sequence below.
                loss *= (end - start) if self.args.loss_averaging else 1
cur_losses.append(loss.data.cpu())
cur_outputs.append([x.data.cpu() for x in output])
            # Transpose cur_outputs (a per-slice list of per-head outputs) and
            # concatenate each head across the time dimension
outputs.append([th.cat(x) for x in zip(*cur_outputs)])
losses.append(th.cat(cur_losses).sum())
if self.args.loss_averaging:
                # Average over the full sequence length (outputs[mi] is a
                # list of tensors, so the length comes from the input)
                losses[mi] /= inp.size(0)
elif self.args.bptt == 0 or not model.accepts_bptt:
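            # No truncation: run the full sequence in a single step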
output, loss = self.__do_model_step(mi, scmap, race, inp, targ, game_name, optimize)
outputs.append([x.data.cpu() for x in output])
            assert loss.numel() == 1
losses.append(loss.data.cpu().sum())
else:
            # TODO bptt sampled with cut ~ Bernoulli(hyperparam)
self.quit_training("BPTT < 0 not implemented")
return outputs, losses
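
The helper `__generate_bptt_slices` is called above but not shown in this excerpt. A minimal sketch of what it appears to do, assuming it yields consecutive `(start, end)` index pairs of at most `args.bptt` timesteps covering the whole sequence:

# Hypothetical sketch, not the actual implementation from compare_models.py:
# yields (start, end) pairs that tile `length` timesteps in chunks of at most
# self.args.bptt, so inp[start:end] never exceeds the truncation window.
def __generate_bptt_slices(self, length):
    step = self.args.bptt
    for start in range(0, length, step):
        yield start, min(start + step, length)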