in train.py [0:0]
def restore_weights_and_initialize(opt):
    """Build savers, open a session, initialize variables and load pretrained weights.

    Reads from ``opt``: ``config`` (session config), ``train_summary_path`` and
    ``val_summary_path`` (FileWriter log dirs), ``R_net_weights`` (checkpoint
    restored into variables whose name contains 'resnet_v1_50') and
    ``Perceptual_net_weights`` (checkpoint restored into variables whose name
    contains 'InceptionResnetV1').

    Returns:
        (saver, train_writer, val_writer, sess) where ``saver`` covers the
        'resnet_v1_50' and 'fc-' variables and keeps up to 50 checkpoints, and
        ``sess`` is the ``tf.InteractiveSession`` that was created.

    Raises:
        Whatever the FileWriter constructors or ``Saver.restore`` raise; the
        session and any writers already opened are closed before re-raising.
    """
    var_list = tf.trainable_variables()
    # Batch-norm moving statistics are not trainable, so trainable_variables()
    # misses them. Append them to this local list so the savers below also
    # save/restore them (they are NOT made trainable).
    var_list += [
        g for g in tf.global_variables()
        if 'moving_mean' in g.name or 'moving_variance' in g.name
    ]

    # Savers used to restore the two pretrained sub-networks.
    resnet_vars = [v for v in var_list if 'resnet_v1_50' in v.name]
    facenet_vars = [v for v in var_list if 'InceptionResnetV1' in v.name]
    saver_resnet = tf.train.Saver(var_list=resnet_vars)
    saver_facenet = tf.train.Saver(var_list=facenet_vars)
    # Saver returned to the caller. A single predicate keeps a variable that
    # matches both patterns from being listed twice (tf.train.Saver rejects
    # duplicate variable names).
    saver = tf.train.Saver(
        var_list=[v for v in var_list
                  if 'resnet_v1_50' in v.name or 'fc-' in v.name],
        max_to_keep=50,
    )

    sess = tf.InteractiveSession(config=opt.config)
    opened_writers = []
    try:
        # Event-file writers for TensorBoard, both recording the session graph.
        train_writer = tf.summary.FileWriter(opt.train_summary_path, sess.graph)
        opened_writers.append(train_writer)
        val_writer = tf.summary.FileWriter(opt.val_summary_path, sess.graph)
        opened_writers.append(val_writer)

        # Initialize everything first, then overwrite the pretrained parts
        # with checkpoint values; the order matters.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver_resnet.restore(sess, opt.R_net_weights)
        saver_facenet.restore(sess, opt.Perceptual_net_weights)
    except BaseException:
        # The caller never receives these objects on failure, so release them
        # here instead of leaking the session and open event files.
        for writer in opened_writers:
            writer.close()
        sess.close()
        raise
    return saver, train_writer, val_writer, sess