# model.py
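# Module-level imports assumed by this excerpt; `optim` and `tfops` (aliased
# as Z) are local modules in this repository.
import tensorflow as tf
import horovod.tensorflow as hvd

import optim
import tfops as Z
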
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):
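    """Build the training/testing graph and return a namespace-like object.

    The class `m` is used as a plain namespace: its attributes (m.train,
    m.test, m.polyak_swap, m.save, m.save_ema, m.restore) are closures over
    the TensorFlow session created by the caller.
    """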
    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr
    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
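    # Optionally use gradient checkpointing (recompute activations during the
    # backward pass) to reduce memory usage at the cost of extra compute.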
    if hps.gradient_checkpointing == 1:
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    optimizer = {'adam': optim.adam, 'adamax': optim.adamax,
                 'adam2': optim.adam2}[hps.optimizer]
    train_op, polyak_swap_op, ema = optimizer(
        all_params, gs, alpha=lr, hps=hps)
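    # With a direct (in-graph) iterator, running train_op pulls the next batch
    # itself; otherwise batches are fetched in Python and fed via placeholders.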
    if hps.direct_iterator:
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:
        def _train(_lr):
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train], {feeds['x']: _x,
                                                      feeds['y']: _y, lr: _lr})[1]
        m.train = _train
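    # Swap the current parameters with their moving-average (Polyak) copies;
    # in the training loop this is typically called around evaluation.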
    m.polyak_swap = lambda: sess.run(polyak_swap_op)
    # === Testing
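    # The test graph reuses the training variables (reuse=True); the second
    # positional argument presumably toggles training-time behavior off.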
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:
        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x,
                                         feeds['y']: _y})
        m.test = _test
    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())
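    # A second saver handles the exponential-moving-average copies of the
    # parameters, via the mapping returned by ema.variables_to_restore().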
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)
    # === Initialize the parameters
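    # Either restore from a checkpoint, or run data-dependent initialization
    # (DDI) of the actnorm layers on the `data_init` batch; afterwards Horovod
    # broadcasts rank 0's variables so all workers start from the same weights.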
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {feeds['x']: data_init['x'],
                                feeds['y']: data_init['y']})
    sess.run(hvd.broadcast_global_variables(0))

    return m
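

# Typical usage from a training script (sketch; `hps`, the iterators, `feeds`,
# `data_init`, `lr` and `f_loss` are assumed to be set up by the caller, and
# `checkpoint_path` is a hypothetical path):
#
#   model = abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator,
#                             data_init, lr, f_loss)
#   train_stats = model.train(1e-3)   # one optimizer step at lr=1e-3
#   model.polyak_swap()               # swap in averaged weights
#   test_stats = model.test()
#   model.polyak_swap()               # swap back before training resumes
#   model.save(checkpoint_path)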