in compare_models.py [0:0]
def __init__(self, args, models=None):
    """Set up per-model losses, device placement, data loaders and weights.

    Weights come from the first source that holds a checkpoint file
    ``<dir>/<models[0].model_name>.pth``: ``args.load`` (resume), then
    ``args.finetune``; if neither has one, ``self.init`` is called instead.

    Args:
        args: Option namespace (argparse-style). Read here: ``check_nan``,
            ``gpu``, ``featurizer``, ``valid_every``, ``n_input_timesteps``,
            ``load`` and ``finetune``. A falsy ``load``/``finetune``
            (``False``, ``None``, ``""``) means "not requested".
            ``args.featurizer`` is copied to ``self.featurizer`` and then set
            to ``None`` on the caller's object before returning.
        models: Iterable of models trained side by side. Each must provide
            ``model_name``, ``parameters()`` and ``cuda(device)`` (presumably
            ``torch.nn.Module`` subclasses -- TODO confirm). Required despite
            the ``None`` default.

    Raises:
        TypeError: if ``models`` is ``None``.
    """
    if models is None:
        raise TypeError("models must be an iterable of models, got None")
    # Materialize once: the models are walked more than once below, and a
    # generator would silently be empty on the second pass.
    models = list(models)

    loss_fns = []
    self.update_number = 0
    for model in models:
        logging.info("Model: {}".format(model))
        loss_fns.append(self.__mk_loss_fn(args))
        # TODO a structured prediction (local and or global) type of loss?
        # TODO a mass conservation (except where we have factories?) loss?
        # TODO the loss/model should make you pay less for small (dx, dy) in prediction
        if args.check_nan:
            register_nan_checks(model)

    def xfer(tensor):
        # Move a tensor/module to GPU `args.gpu`; a negative id keeps it as-is.
        if args.gpu >= 0:
            return tensor.cuda(args.gpu)
        return tensor

    models = [xfer(model) for model in models]
    self.args = args
    self.xfer = xfer
    self.models = models
    self.loss_fns = loss_fns
    self.train_dl = self.__get_dataloader()
    self.other_dl = self.__get_dataloader(both=False)
    self.featurizer = args.featurizer
    self.optimizers = []
    self.valid_every = args.valid_every
    self.plot_loss_every = 100
    self.n_input_timesteps = args.n_input_timesteps
    self.train_loss_pane = None
    self.save_timer = time.time()

    def has_checkpoint(directory):
        # Truthiness test first: False/None/"" mean "not requested", and it
        # keeps path.exists(None) (which raises TypeError) from being reached.
        return bool(directory) and path.exists(directory) and path.exists(
            path.join(directory, models[0].model_name + ".pth"))

    # Use exactly one weight source, in priority order resume > finetune >
    # fresh init. The branches must be chained: otherwise a successful
    # resume also falls into the `else` and calls self.init after loading.
    if has_checkpoint(args.load):
        logging.info("Loading model from {}".format(args.load))
        self.load(args, models, args.load)
    elif has_checkpoint(args.finetune):
        logging.info("finetuneing model from {}".format(args.finetune))
        self.load(args, models, args.finetune)
    else:
        logging.info("No previous model found, initting new models")
        self.init(args, models)

    for model in models:
        nparam = sum(param.data.numel() for param in model.parameters())
        # Custom level 42 lies between ERROR (40) and CRITICAL (50);
        # presumably chosen so parameter counts survive quiet log configs.
        logging.log(42, "n_param {} {}".format(model.model_name, nparam))
    # The featurizer can't be pickled, so drop it from args (presumably args
    # is serialized later, e.g. alongside checkpoints -- confirm). This
    # mutates the caller's args object; the featurizer stays on self.
    args.featurizer = None  # can't pickle