in src/train.py [0:0]
def validate_batch(train_config: TrainConfig, epoch, train_loss, model_idx=-1):
    """Evaluate on the validation loader, log the statistics and return the mean loss.

    For every batch of ``train_config.valid_data_loader`` a prediction is
    obtained in one of two ways:

    * ``model_idx == -1``: ``train_config.inference`` is called with
      ``gradient=False`` and the last output / inference dict is used.
    * otherwise: only ``train_config.models[model_idx]`` is run. Its input is
      built by ``train_config.f_in[model_idx].batch`` and the training targets
      of all preceding models are passed as ``prev_outs``.

    The batch loss is computed by ``train_config.losses[model_idx]``; the batch
    accuracy is the fraction of output elements whose absolute difference to
    the target is below 0.001.

    Side effects:
        * prints a summary line to stdout,
        * appends a record to ``{logDir}logs.txt``,
        * appends a row to the CSV ``{logDir}{trainStatsName}`` (writing the
          header first if the file does not exist yet),
        * calls ``plot_training_stats`` for several column combinations.

    Args:
        train_config: Training configuration holding the data loader, models,
            losses, feature encoders, config file and log directory.
        epoch: Epoch identifier written to the logs.
        train_loss: Training loss of this epoch, logged next to the
            validation loss.
        model_idx: Index of the model to validate, or ``-1`` to validate the
            output of the full inference pipeline.

    Returns:
        The mean of the per-batch validation losses as a Python float.
    """
    c_file = train_config.config_file
    batch_losses = []
    batch_accuracies = []

    for sample_data in tqdm(train_config.valid_data_loader, desc='validating batch', position=0, leave=True):
        # get sample input data
        sample_data_wrapper = create_sample_wrapper(sample_data, train_config)

        # inference
        if model_idx == -1:
            outs, inference_dicts = train_config.inference(sample_data_wrapper, gradient=False)
            out = outs[-1]
            inference_dict = inference_dicts[-1]
        else:
            f_in = train_config.f_in[model_idx]
            model = train_config.models[model_idx]
            prev_outs = [sample_data_wrapper.get_train_target(prev_model_idx)
                         for prev_model_idx in range(model_idx)]
            # Validation never back-propagates, so skip building the autograd
            # graph (mirrors gradient=False in the branch above).
            with torch.no_grad():
                inference_dict = f_in.batch(sample_data_wrapper.get_batch_input(model_idx), prev_outs=prev_outs)
                x_batch = inference_dict[FeatureSetKeyConstants.input_feature_batch]
                out = model(x_batch)

        y_batch = sample_data_wrapper.get_train_target(model_idx)
        # NOTE(review): assumes every batch holds exactly batchImages * samples
        # rows - confirm this also holds for the last batch of the loader.
        y_batch = y_batch.reshape(c_file.batchImages * c_file.samples, -1)

        loss_batch = train_config.losses[model_idx](out, y_batch, inference_dict=inference_dict)
        # Keep a plain float instead of the tensor so that no graph or device
        # memory stays referenced for the rest of the validation pass.
        batch_losses.append(float(loss_batch))

        diff = abs(out - y_batch)
        # numel() counts all elements, whatever the number of dimensions.
        batch_accuracies.append(float((diff < 0.001).sum()) / float(diff.numel()))

    loss = torch.mean(torch.tensor(batch_losses)).item()
    accuracy = torch.mean(torch.tensor(batch_accuracies)).item()

    print(f"\nvalidation epoch={epoch:<10} loss={loss:.8f} acc={accuracy:.8f}")

    # NOTE(review): the record is terminated with '\r' rather than '\n' -
    # confirm this is intended before changing it.
    with open(f"{train_config.logDir}logs.txt", "a") as f:
        f.write(f"epoch={epoch} loss={loss:.4f} acc={accuracy:.8f} train_loss={train_loss:.8f}\r")

    stats_path = f"{train_config.logDir}{c_file.trainStatsName}"
    # The header is only needed when the CSV file is created by this call.
    add_header = not os.path.isfile(stats_path)
    with open(stats_path, 'a', newline='') as csv_file:
        fieldnames = ['epoch', 'loss', 'accuracy', 'train_loss']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        if add_header:
            writer.writeheader()
        writer.writerow({'epoch': f'{epoch}', 'loss': f'{loss}', 'accuracy': f'{accuracy}',
                         'train_loss': f'{train_loss}'})

    # Regenerate the plots from the updated CSV: single curves first, then
    # the combined ones.
    plot_columns = ('loss', 'train_loss', 'accuracy',
                    ['loss', 'train_loss', 'accuracy'], ['loss', 'train_loss'])
    for y_columns in plot_columns:
        plot_training_stats(train_config.logDir, c_file.trainStatsName, 'epoch', y_columns)

    return loss