in reconstruction_model.py [0:0]
def backward(self, is_train=True):
    """Build the optimisation op for this model (TF1 graph mode).

    If ``is_train`` is truthy:

    * Gradients of ``self.loss`` are taken only with respect to trainable
      variables whose name contains ``'resnet_v1_50'`` or ``'fc-'``. Every
      other trainable variable is left out of the update.
    * The resulting ``self.Optimizer.apply_gradients`` op is stored in
      ``self.train_op``. It increments ``self.opt.global_step`` and runs
      under a control dependency on every op in the ``UPDATE_OPS``
      collection, so that batch-norm parameter updates run with each
      training step.

    If ``is_train`` is falsy, nothing is built. ``self.train_op`` is not
    assigned, so no variables can be updated.

    Args:
        is_train: whether to construct the training op.
    """
    # Outside the training stage there is nothing to construct.
    if not is_train:
        return

    # Snapshot the update-op collection before any gradient ops are added.
    bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Restrict optimisation to the backbone and the fully-connected heads.
    selected = []
    for variable in tf.trainable_variables():
        var_name = variable.name
        if 'resnet_v1_50' in var_name or 'fc-' in var_name:
            selected.append(variable)

    gradients = tf.gradients(self.loss, selected)
    grads_and_vars = zip(gradients, selected)

    # Tie the update ops (e.g. BN statistics) to the train op so they are
    # executed whenever a training step is run.
    with tf.control_dependencies(bn_updates):
        self.train_op = self.Optimizer.apply_gradients(
            grads_and_vars,
            global_step=self.opt.global_step,
        )