in TensorFlow/squeezenet/src/train_squeezenet.py [0:0]

import os

import tensorflow as tf
from tensorflow.python.client import timeline

# Project-local modules; only this function is shown in the excerpt, so the
# exact import paths below are assumptions.
import inputs
import metrics
import networks
from deployment import model_deploy  # slim-style multi-clone deployment

def _run(args):
    network = networks.catalogue[args.network](args)
    deploy_config = _configure_deployment(args.num_gpus, args.clone_on_cpu)
    sess = tf.compat.v1.Session(config=_configure_session())

    with tf.device(deploy_config.variables_device()):
        global_step = tf.compat.v1.train.create_global_step()

    with tf.device(deploy_config.optimizer_device()):
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.learning_rate
        )
    '''Inputs'''
    with tf.device(deploy_config.inputs_device()), tf.name_scope('inputs'):
        pipeline = inputs.Pipeline(args, sess)
        examples, labels = pipeline.data
        images = examples['image']

        # Split each batch evenly across the model clones, one slice per GPU.
        image_splits = tf.split(
            value=images,
            num_or_size_splits=deploy_config.num_clones,
            name='split_images'
        )
        label_splits = tf.split(
            value=labels,
            num_or_size_splits=deploy_config.num_clones,
            name='split_labels'
        )
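
    # NOTE: inputs_device() resolves to the CPU by default in slim's
    # DeploymentConfig, so the input pipeline does not compete with the model
    # clones for GPU memory and compute.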
    '''Model Creation'''
    model_dp = model_deploy.deploy(
        config=deploy_config,
        model_fn=_clone_fn,
        optimizer=optimizer,
        kwargs={
            'images': image_splits,
            'labels': label_splits,
            'index_iter': iter(range(deploy_config.num_clones)),
            'network': network,
            'is_training': pipeline.is_training
        }
    )
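
    # In slim's model_deploy, deploy() returns a DeployedModel namedtuple:
    # train_op applies the clone-averaged gradients, total_loss sums the
    # per-clone (plus regularization) losses, and clones exposes each clone's
    # outputs, which the metrics below consume.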
    '''Metrics'''
    train_metrics = metrics.Metrics(
        labels=labels,
        clone_predictions=[clone.outputs['predictions']
                           for clone in model_dp.clones],
        device=deploy_config.variables_device(),
        name='training'
    )
    validation_metrics = metrics.Metrics(
        labels=labels,
        clone_predictions=[clone.outputs['predictions']
                           for clone in model_dp.clones],
        device=deploy_config.variables_device(),
        name='validation',
        padded_data=True
    )

    # Re-initializing validation also rewinds its dataset iterator, so the
    # eval hook below can make repeated full passes over the validation set.
    validation_init_op = tf.group(
        pipeline.validation_iterator.initializer,
        validation_metrics.reset_op
    )
    train_op = tf.group(
        model_dp.train_op,
        train_metrics.update_op
    )
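
    # tf.group returns a single op with no output value that runs everything
    # it was given, so one sess.run(train_op) both applies the gradients and
    # advances the streaming training-accuracy counters.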
    '''Summaries'''
    with tf.device(deploy_config.variables_device()):
        train_writer = tf.compat.v1.summary.FileWriter(args.model_dir, sess.graph)
        eval_dir = os.path.join(args.model_dir, 'eval')
        eval_writer = tf.compat.v1.summary.FileWriter(eval_dir, sess.graph)
        tf.compat.v1.summary.scalar('accuracy', train_metrics.accuracy)
        tf.compat.v1.summary.scalar('loss', model_dp.total_loss)
        all_summaries = tf.compat.v1.summary.merge_all()
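
    # Writing eval events to a separate subdirectory lets TensorBoard plot the
    # training and validation accuracy as two runs of the same tag.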
    if args.keep_last_n_checkpoints:
        '''Model Checkpoints'''
        saver = tf.compat.v1.train.Saver(max_to_keep=args.keep_last_n_checkpoints)
        save_path = os.path.join(args.model_dir, 'model.ckpt')

    '''Model Initialization'''
    last_checkpoint = tf.train.latest_checkpoint(args.model_dir)
    init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                       tf.compat.v1.local_variables_initializer())
    sess.run(init_op)
    if args.keep_last_n_checkpoints and last_checkpoint:
        # Restoring after init overwrites the fresh values and also restores
        # global_step, so training resumes where the checkpoint left off.
        saver.restore(sess, last_checkpoint)
    starting_step = sess.run(global_step)
    if args.trace:
        # FULL_TRACE records per-op timing so a Chrome trace can be written
        # at the end of each step (see the bottom of the main loop).
        options = tf.compat.v1.RunOptions(
            trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
        run_metadata = tf.compat.v1.RunMetadata()
    else:
        options = None
        run_metadata = None
    '''Main Loop'''
    for train_step in range(starting_step, args.max_train_steps):
        sess.run(train_op,
                 feed_dict=pipeline.training_data,
                 options=options,
                 run_metadata=run_metadata)

        '''Summary Hook'''
        if args.summary_interval > 0 and train_step % args.summary_interval == 0:
            results = sess.run(
                fetches={'accuracy': train_metrics.accuracy,
                         'summary': all_summaries},
                feed_dict=pipeline.training_data
            )
            train_writer.add_summary(results['summary'], train_step)
            print('Train Step {:<5}: {:>.4}'
                  .format(train_step, results['accuracy']))

        if args.keep_last_n_checkpoints:
            '''Checkpoint Hooks'''
            if args.checkpoint_interval > 0 and train_step % args.checkpoint_interval == 0:
                saver.save(sess, save_path, global_step)
                # Start a fresh training-accuracy window for the next interval.
                sess.run(train_metrics.reset_op)

        '''Eval Hook'''
        if args.validation_interval > 0 and train_step % args.validation_interval == 0:
            # Stream the entire validation set through the metric update op;
            # the iterator signals exhaustion with OutOfRangeError.
            while True:
                try:
                    sess.run(
                        fetches=validation_metrics.update_op,
                        feed_dict=pipeline.validation_data
                    )
                except tf.errors.OutOfRangeError:
                    break
            results = sess.run({'accuracy': validation_metrics.accuracy})
            print('Evaluation Step {:<5}: {:>.4}'
                  .format(train_step, results['accuracy']))
            summary = tf.compat.v1.Summary(value=[
                tf.compat.v1.Summary.Value(
                    tag='accuracy', simple_value=results['accuracy']),
            ])
            eval_writer.add_summary(summary, train_step)
            sess.run(validation_init_op)  # Reinitialize dataset and metrics

        if args.trace:
            # Dump a per-step Chrome trace, viewable at chrome://tracing.
            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            trace_path = os.path.join(args.model_dir,
                                      f'cifar_trace_{train_step}.json')
            with open(trace_path, 'w') as f:
                f.write(chrome_trace)
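

# The two helpers referenced above, _configure_deployment and
# _configure_session, are not part of this excerpt. The definitions below are
# minimal sketches of what they might look like, assuming slim's
# model_deploy.DeploymentConfig and a stock ConfigProto; the project's real
# implementations may differ.
def _configure_deployment(num_gpus, clone_on_cpu):
    # One model clone per GPU; clone_on_cpu=True pins clone ops to the CPU,
    # which is handy for debugging on machines without GPUs.
    return model_deploy.DeploymentConfig(num_clones=num_gpus,
                                         clone_on_cpu=clone_on_cpu)


def _configure_session():
    # Grow GPU memory on demand rather than reserving it all up front, and
    # let TensorFlow fall back to the CPU for ops without a GPU kernel.
    config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    return config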