def abstract_model_xy()

in model.py [0:0]


def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):
    """Assemble train/test/save/restore entry points for an (x, y) model.

    The function builds the training loss with ``f_loss``, differentiates it
    with respect to every variable returned by ``tf.trainable_variables()``,
    hands the gradients to the optimizer named by ``hps.optimizer``, builds
    the test statistics with ``reuse=True``, creates two savers, and then
    either restores from ``hps.restore_path`` or runs an initialization pass
    fed with ``data_init``. The last step always runs
    ``hvd.broadcast_global_variables(0)``.

    Args:
        sess: Object whose ``run`` method executes fetches; exposed as
            ``m.sess``.
        hps: Hyper-parameter object. This function reads
            ``gradient_checkpointing``, ``optimizer``, ``direct_iterator``
            and ``restore_path`` and forwards the whole object to the
            optimizer.
        feeds: Mapping with keys ``'x'`` and ``'y'`` whose values are used
            as feed-dict keys; exposed as ``m.feeds``.
        train_iterator: Passed to ``f_loss`` when building the training
            graph. When ``hps.direct_iterator`` is falsy it is also called
            with no arguments on every training step and must return a
            pair ``(x, y)``.
        test_iterator: Same role as ``train_iterator`` for the test graph.
        data_init: Mapping with keys ``'x'`` and ``'y'``; only read when
            ``hps.restore_path`` is the empty string.
        lr: Object passed to the optimizer as ``alpha`` and used as the
            feed-dict key for the per-step value given to ``m.train``;
            exposed as ``m.lr``.
        f_loss: Callable invoked as ``f_loss(iterator, flag)`` and
            ``f_loss(iterator, flag, reuse=True)``. The first two calls must
            return a ``(loss, stats)`` pair; the result of the third call
            (made with ``None`` as iterator) is passed unchanged to
            ``sess.run`` as the fetches of the initialization pass.

    Returns:
        The class ``m`` itself (not an instance) carrying the attributes
        ``sess``, ``feeds``, ``lr``, ``train``, ``polyak_swap``, ``test``,
        ``save_ema``, ``save`` and ``restore``.

    Raises:
        KeyError: If ``hps.optimizer`` is not ``'adam'``, ``'adamax'`` or
            ``'adam2'``.
    """

    # == Namespace class; every field below is attached to the class object
    class m(object):
        pass

    m.sess, m.feeds, m.lr = sess, feeds, lr

    # === Loss and optimizer
    train_loss, train_stats = f_loss(train_iterator, True)
    params = tf.trainable_variables()
    if hps.gradient_checkpointing == 1:
        # Imported lazily so the module is only required when it is enabled.
        from memory_saving_gradients import gradients as checkpointed_gradients
        grads = checkpointed_gradients(train_loss, params)
    else:
        grads = tf.gradients(train_loss, params)

    optimizers_by_name = {
        'adam': optim.adam,
        'adamax': optim.adamax,
        'adam2': optim.adam2,
    }
    build_optimizer = optimizers_by_name[hps.optimizer]
    train_op, polyak_swap_op, ema = build_optimizer(
        params, grads, alpha=lr, hps=hps)

    # Only the statistics (second fetch) are handed back to the caller; the
    # train op is fetched for its side effect.
    if hps.direct_iterator:
        def run_train_step(_lr):
            fetched = sess.run([train_op, train_stats], {lr: _lr})
            return fetched[1]
    else:
        def run_train_step(_lr):
            batch_x, batch_y = train_iterator()
            step_feed = {feeds['x']: batch_x, feeds['y']: batch_y, lr: _lr}
            fetched = sess.run([train_op, train_stats], step_feed)
            return fetched[1]
    m.train = run_train_step

    def run_polyak_swap():
        return sess.run(polyak_swap_op)
    m.polyak_swap = run_polyak_swap

    # === Testing
    # The test loss tensor itself is not used, only the statistics.
    _, test_stats = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        def run_test_step():
            return sess.run(test_stats)
    else:
        def run_test_step():
            batch_x, batch_y = test_iterator()
            step_feed = {feeds['x']: batch_x, feeds['y']: batch_y}
            return sess.run(test_stats, step_feed)
    m.test = run_test_step

    # === Saving and restoring
    full_saver = tf.train.Saver()
    ema_saver = tf.train.Saver(ema.variables_to_restore())

    def save_ema_checkpoint(path):
        return ema_saver.save(sess, path, write_meta_graph=False)

    def save_checkpoint(path):
        return full_saver.save(sess, path, write_meta_graph=False)

    def restore_checkpoint(path):
        return full_saver.restore(sess, path)

    m.save_ema = save_ema_checkpoint
    m.save = save_checkpoint
    m.restore = restore_checkpoint

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        # Rebuild the loss with init=True for the two scoped helpers
        # (presumably a data-dependent initialization pass -- confirm in Z).
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            init_fetches = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        init_feed = {feeds['x']: data_init['x'],
                     feeds['y']: data_init['y']}
        sess.run(init_fetches, init_feed)

    # NOTE(review): hvd is presumably Horovod, in which case this syncs every
    # worker to the variables held by rank 0 -- confirm against the imports.
    sess.run(hvd.broadcast_global_variables(0))

    return m