def __init__()

in compare_models.py


    def __init__(self, args, models=None):
        # assumes module-level imports: logging, time, and os.path (as path)
        models = models or []  # guard: the default None is not iterable
        loss_fns = []
        self.update_number = 0
        for model in models:
            logging.info("Model: {}".format(model))
            loss_fns.append(self.__mk_loss_fn(args))
            # TODO a structured prediction (local and/or global) type of loss?
            # TODO a mass conservation (except where we have factories?) loss?
            # TODO the loss/model should make you pay less for small (dx, dy) in prediction

            if args.check_nan:
                # attach hooks that flag NaNs in forward outputs (sketched below)
                register_nan_checks(model)

        def xfer(tensor):
            # move a tensor or module onto the configured GPU; identity on CPU
            if args.gpu >= 0:
                return tensor.cuda(args.gpu)
            return tensor

        models = [xfer(model) for model in models]

        self.args = args
        self.xfer = xfer
        self.models = models
        self.loss_fns = loss_fns
        self.train_dl = self.__get_dataloader()
        self.other_dl = self.__get_dataloader(both=False)
        self.featurizer = args.featurizer
        self.optimizers = []

        self.valid_every = args.valid_every
        self.plot_loss_every = 100
        self.n_input_timesteps = args.n_input_timesteps
        self.train_loss_pane = None

        self.save_timer = time.time()

        if args.load is not False and path.exists(args.load) and \
           path.exists(path.join(args.load, models[0].model_name + ".pth")):
            logging.info("Loading model from {}".format(args.load))
            self.load(args, models, args.load)
        # elif, not if: otherwise a successful --load with no --finetune
        # would fall through to the else branch and re-init the models
        elif args.finetune != "" and path.exists(args.finetune) and \
             path.exists(path.join(args.finetune, models[0].model_name + ".pth")):
            logging.info("Finetuning model from {}".format(args.finetune))
            self.load(args, models, args.finetune)
        else:
            logging.info("No previous model found, initializing new models")
            self.init(args, models)

        for model in models:
            # per-model parameter count, logged at the custom level 42
            nparam = sum(param.data.numel() for param in model.parameters())
            logging.log(42, "n_param {} {}".format(model.model_name, nparam))

        args.featurizer = None  # the featurizer can't be pickled, so drop it before args are saved
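
The `register_nan_checks` helper is defined elsewhere in the repository; a minimal sketch of how such a check is commonly built with PyTorch forward hooks (the hook body and the logging call here are assumptions, not the repo's implementation):

    import logging

    import torch

    def register_nan_checks(model):
        # Illustrative only: log whenever any submodule emits a NaN output.
        def check_nan(module, inputs, output):
            outs = output if isinstance(output, tuple) else (output,)
            for out in outs:
                if isinstance(out, torch.Tensor) and torch.isnan(out).any():
                    logging.error("NaN in output of %s", module.__class__.__name__)

        for module in model.modules():
            module.register_forward_hook(check_nan)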
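
The `xfer` closure uses the index-based `.cuda()` API; an equivalent sketch against `torch.device`, assuming the same semantics for `args.gpu >= 0` (the factory name is illustrative):

    import torch

    def make_xfer(gpu_index):
        # equivalent to the xfer closure above, via torch.device and .to()
        device = torch.device("cuda:{}".format(gpu_index) if gpu_index >= 0 else "cpu")
        return lambda x: x.to(device)

`.to()` works uniformly on tensors and modules, which is why a single transfer function can move both.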
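
Both the load and finetune branches resolve a per-model checkpoint as `<directory>/<model_name>.pth`; the same lookup isolated for clarity (the helper name is hypothetical):

    from os import path

    def checkpoint_path(directory, model_name):
        # mirrors the existence check in __init__: <dir>/<model_name>.pth
        return path.join(directory, model_name + ".pth")

    # e.g. checkpoint_path("runs/exp1", "unet") -> "runs/exp1/unet.pth"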