in src/train.py [0:0]
def _init_horovod(fw_params):
    """Initialise Horovod when SageMaker MPI is enabled in ``fw_params``.

    Args:
        fw_params: Mapping of framework parameters; only the
            ``'sagemaker_mpi_enabled'`` key is read.

    Returns:
        The initialised ``horovod.keras`` module, or ``None`` when the key is
        missing or falsy.
    """
    if 'sagemaker_mpi_enabled' not in fw_params or not fw_params['sagemaker_mpi_enabled']:
        return None

    # Imported lazily: Horovod is only touched on the MPI path.
    import horovod.keras as hvd

    # Horovod: initialize Horovod.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    return hvd


def _build_callbacks(hvd, is_chief, output_dir):
    """Return the Keras callbacks list for this process.

    The two Horovod callbacks are added whenever ``hvd`` is not ``None``; the
    per-epoch weights checkpoint is added only when ``is_chief`` is true.
    """
    callbacks = []
    if hvd is not None:
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        callbacks.append(hvd.callbacks.MetricAverageCallback())
    if is_chief:
        callbacks.append(ModelCheckpoint(output_dir + '/checkpoint-{epoch}.ckpt',
                                         save_weights_only=True,
                                         verbose=2))
    return callbacks


def _evaluate_per_slide(model, test_dataset):
    """Predict every test element, average the scores per slide, report accuracy.

    Each element of ``test_dataset`` is indexed as ``(image, label, slide)``.
    Predictions are grouped by the decoded slide value, averaged with
    ``np.mean`` and reduced with ``np.argmax``; the resulting classes are
    compared with the first label seen for each slide.

    Returns:
        The value returned by ``accuracy_score`` for the per-slide results.
    """
    print("-------------------------Evaluation begins ----------------------------------------------------")

    # Accumulate per-slide predictions
    pred_dict = {}
    for i, element in enumerate(test_dataset):
        if (i + 1) % 1000 == 0:
            # The counter enumerates dataset elements (tiles), not slides.
            print("Computing scores for tile {}...".format(i + 1))
            logging.info("Computing scores for tile {}...".format(i + 1))
        image = element[0].numpy()
        label = element[1].numpy()
        slide = element[2].numpy().decode()
        if slide not in pred_dict:
            # The label of the first element seen for a slide is used for the whole slide.
            pred_dict[slide] = {'y_true': label, 'y_pred': []}
        # predict() is given a batch of one, so take the single result row.
        pred = model.predict(np.expand_dims(image, axis=0))[0]
        pred_dict[slide]['y_pred'].append(pred)

    # Aggregate per-slide predictions
    y_true = []
    y_pred = []
    for key, value in pred_dict.items():
        slide_true = value['y_true']
        mean_pred_scores = np.mean(np.vstack(value['y_pred']), axis=0)
        mean_pred_class = np.argmax(mean_pred_scores)
        y_true.append(slide_true)
        y_pred.append(mean_pred_class)
        print('Slide {}: True Label = {}, Prediction = {}'.format(key, slide_true, mean_pred_class))
        logging.info('Slide {}: True Label = {}, Prediction = {}'.format(key, slide_true, mean_pred_class))

    acc = accuracy_score(y_true, y_pred)
    print('Per-Slide Test accuracy: {}'.format(acc))
    logging.info('Per-Slide Test accuracy: {}'.format(acc))
    return acc


def main(args):
    """Train the model, evaluate it per slide and save it.

    Horovod is used when ``args.fw_params['sagemaker_mpi_enabled']`` is truthy;
    otherwise training runs in a single process. Checkpointing, evaluation and
    model saving are performed only by the chief process (Horovod rank 0, or
    the only process when Horovod is not used).

    Args:
        args: Object providing ``fw_params``, ``output_dir``,
            ``model_output_dir``, ``learning_rate``, ``batch_size``,
            ``epochs``, ``num_train`` and ``num_val``.

    Raises:
        KeyError: If the ``SM_CURRENT_HOST`` environment variable is not set.
    """
    # hvd is always bound (module or None) so later code can test it safely.
    hvd = _init_horovod(args.fw_params)
    mpi = hvd is not None
    is_chief = not mpi or hvd.rank() == 0

    callbacks = _build_callbacks(hvd, is_chief, args.output_dir)

    current_host = os.environ['SM_CURRENT_HOST']
    if mpi:
        # There is no rank to report when Horovod is not in use.
        print("The current horovod rank is ", hvd.rank())
    print("the current host is ", current_host)

    print("Training dataset being loaded -----------------")
    train_dataset = train_input_fn(hvd, mpi)
    print("valid dataset being loaded -----------------")
    valid_dataset = valid_input_fn(hvd, mpi)
    print("Test dataset being loaded -----------------")
    test_dataset = test_input_fn()

    logging.info("configuring model")
    model = model_def(args.learning_rate, mpi, hvd)

    logging.info("Starting training")
    size = hvd.size() if mpi else 1
    print("the size is ", size)

    # Fit the model; step counts are divided by the number of Horovod processes.
    model.fit(train_dataset,
              steps_per_epoch=((args.num_train // args.batch_size) // size),
              epochs=args.epochs,
              validation_data=valid_dataset,
              validation_steps=((args.num_val // args.batch_size) // size),
              callbacks=callbacks,
              verbose=2)

    if is_chief:
        # Evaluate the model at rank 0
        _evaluate_per_slide(model, test_dataset)

        model_path = '{}/00000001'.format(args.model_output_dir)
        model.save(model_path)
        # NOTE(review): this second save is issued by the chief process only;
        # confirm nothing relies on non-zero ranks writing args.model_output_dir.
        model.save(args.model_output_dir)