in train.py [0:0]
def train():
    """Build the training/validation graph and run the optimisation loop.

    Steps:
      1. If ``./BFM/BFM_model_front.mat`` does not exist, call
         ``transferBFM09()`` to convert the original BFM model into the
         format this project uses.
      2. Parse command-line options and build an ``Option`` from them.
      3. Wire a ``Reconstruction_model`` to the training iterator (with
         parameter updates) and then to the validation iterator (forward
         pass only).
      4. Run ``opt.train_maxiter`` training steps, periodically writing
         scalar/image summaries and checkpoints.

    The summary writers and the session are always closed on exit, so
    buffered summaries are flushed even if training is interrupted. A
    checkpoint of the final iteration is written if the periodic schedule
    did not already cover it.
    """
    # Convert the original BFM model into this project's format, only once.
    if not os.path.isfile('./BFM/BFM_model_front.mat'):
        transferBFM09()

    with tf.Graph().as_default():
        # Training options.
        args = parse_args()
        opt = Option(model_name=args.model_name)
        opt.data_path = [args.data_path]
        opt.val_data_path = [args.val_data_path]

        # ---- Training branch: forward pass + variable update ----
        train_iterator = load_dataset(opt)
        model = Reconstruction_model(opt)
        model.set_input(train_iterator)
        model.step(is_train=True)
        model.summarize()
        # The validation pass below reuses `model` and re-populates
        # summary_stat / summary_img, so grab the training handles first.
        train_stat = model.summary_stat
        train_img_stat = model.summary_img
        train_op = model.train_op
        photo_error = model.photo_loss
        lm_error = model.landmark_loss
        id_error = model.perceptual_loss

        # ---- Validation branch: forward pass only, no variable update ----
        val_iterator = load_dataset(opt, train=False)
        model.set_input(val_iterator)
        model.step(is_train=False)
        model.summarize()
        val_stat = model.summary_stat
        val_img_stat = model.summary_img

        # Initialization (restore or initialize weights, create writers).
        saver, train_writer, val_writer, sess = restore_weights_and_initialize(opt)
        # Freeze the graph so no new op can be added during training.
        sess.graph.finalize()

        def save_checkpoint(step):
            """Write a checkpoint named after the given iteration."""
            saver.save(sess, os.path.join(opt.model_save_path, 'iter_%d.ckpt' % step))

        last_saved = None
        try:
            for i in range(opt.train_maxiter):
                _, ph_loss, lm_loss, id_loss = sess.run(
                    [train_op, photo_error, lm_error, id_error])
                # The landmark loss is reported as its square root.
                print('Iter: %d; lm_loss: %f ; photo_loss: %f; id_loss: %f\n'
                      % (i, np.sqrt(lm_loss), ph_loss, id_loss))

                # Training scalar summaries every <train_summary_iter> iterations.
                if i % opt.train_summary_iter == 0:
                    train_writer.add_summary(sess.run(train_stat), i)
                # Training image summaries every <image_summary_iter> iterations.
                if i % opt.image_summary_iter == 0:
                    train_writer.add_summary(sess.run(train_img_stat), i)
                # Validation summaries every <val_summary_iter> iterations.
                if i % opt.val_summary_iter == 0:
                    val_summary, val_img_summary = sess.run([val_stat, val_img_stat])
                    val_writer.add_summary(val_summary, i)
                    val_writer.add_summary(val_img_summary, i)
                # Save model variables every <save_iter> iterations.
                if i % opt.save_iter == 0:
                    save_checkpoint(i)
                    last_saved = i

            # The periodic schedule can miss the last iteration; save it so the
            # final trained weights are never lost.
            last_iter = opt.train_maxiter - 1
            if last_iter >= 0 and last_saved != last_iter:
                save_checkpoint(last_iter)
        finally:
            # Flush pending summaries to disk and release the session even if
            # training stops early (e.g. an exception or KeyboardInterrupt).
            train_writer.close()
            val_writer.close()
            sess.close()