in tbsm_pytorch.py
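# NOTE: this function relies on names defined elsewhere in tbsm_pytorch.py:
# numpy as np, torch, and the helpers time_wrap, data_wrap, loss_fn_wrap,
# and iterate_val_data.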
def iterate_train_data(args, train_ld, val_ld, tbsm, k, use_gpu, device,
                       writer, losses, accuracies, isMainTraining):
    # select number of batches (args.num_batches == 0 means the full loader)
if isMainTraining:
nbatches = len(train_ld) if args.num_batches == 0 else args.num_batches
else:
nbatches = len(train_ld)
    # specify the optimizer algorithm (note: Adagrad state is re-created on
    # every call, so its accumulators reset each epoch k)
    optimizer = torch.optim.Adagrad(tbsm.parameters(), lr=args.learning_rate)
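    # running totals, reset at the end of each reporting window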
total_time = 0
total_loss = 0
total_accu = 0
total_iter = 0
total_samp = 0
max_gA_test = 0
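    # main loop over training mini-batches; the loader yields dense inputs X
    # (a list of tensors here, given batchSize = X[0].shape[0]), DLRM-style
    # sparse offsets lS_o and indices lS_i, and targets T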
for j, (X, lS_o, lS_i, T) in enumerate(train_ld):
if j >= nbatches:
break
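        # start per-iteration timer (time_wrap presumably synchronizes the
        # device before reading the clock when running on GPU)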
t1 = time_wrap(use_gpu, device)
batchSize = X[0].shape[0]
# forward pass
        Z = tbsm(*data_wrap(X, lS_o, lS_i, use_gpu, device))
# loss
E = loss_fn_wrap(Z, T, use_gpu, device)
        # move loss, predictions and targets to host memory as numpy arrays
        L = E.detach().cpu().numpy()
        z = Z.detach().cpu().numpy()
        t = T.detach().cpu().numpy()
        # count correct predictions: round scores and targets and compare
        A = np.sum((np.round(z, 0) == np.round(t, 0)).astype(np.uint8))
        optimizer.zero_grad()
        # backward pass (retain_graph=True keeps the autograd graph alive;
        # it is only required if backward is called again on the same graph)
        E.backward(retain_graph=True)
        # weights update
        optimizer.step()
t2 = time_wrap(use_gpu, device)
total_time += t2 - t1
total_loss += (L * batchSize)
total_accu += A
total_iter += 1
total_samp += batchSize
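        # decide whether to report: every print_freq iterations and at epoch end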
print_tl = ((j + 1) % args.print_freq == 0) or (j + 1 == nbatches)
# print time, loss and accuracy
        if print_tl and isMainTraining:
            gT = 1000.0 * total_time / total_iter if args.print_time else -1
            gL = total_loss / total_samp
            gA = total_accu / total_samp
            str_run_type = "inference" if args.inference_only else "training"
            print(
                "Finished {} it {}/{} of epoch {}, ".format(
                    str_run_type, j + 1, nbatches, k
                )
                + "{:.2f} ms/it, loss {:.8f}, accuracy {:3.3f} %".format(
                    gT, gL, gA * 100
                )
            )
            # reset running totals for the next reporting window
            total_time = 0
            total_loss = 0
            total_accu = 0
            total_iter = 0
            total_samp = 0
        if isMainTraining:
            should_test = (
                (args.test_freq > 0 and (j + 1) % args.test_freq == 0)
                or (j + 1 == nbatches)
            )
        else:
            # tuning/debug mode: validate once, about 5% of the way into the epoch
            should_test = (j == min(int(0.05 * len(train_ld)), len(train_ld) - 1))
# validation run
if should_test:
            total_accu_test, total_samp_test, total_loss_val = iterate_val_data(
                val_ld, tbsm, use_gpu, device
            )
gA_test = total_accu_test / total_samp_test
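            # tuning/debug mode needs only this one validation result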
if not isMainTraining:
break
gL_test = total_loss_val / total_samp_test
print("At epoch {:d} validation accuracy is {:3.3f} %".
format(k, gA_test * 100))
if args.enable_summary and isMainTraining:
writer.add_scalars('train and val loss',
{'train_loss': gL,
'val_loss': gL_test},
k * len(train_ld) + j)
writer.add_scalars('train and val accuracy',
{'train_acc': gA * 100,
'val_acc': gA_test * 100},
k * len(train_ld) + j)
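                # note: np.append returns new arrays, so these assignments
                # rebind the local names only; the caller's arrays are unchanged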
                losses = np.append(losses, np.array([[j, gL, gL_test]]), axis=0)
                accuracies = np.append(
                    accuracies, np.array([[j, gA * 100, gA_test * 100]]), axis=0
                )
# save model if best so far
            if gA_test > max_gA_test and isMainTraining:
                print("Saving current model...")
                max_gA_test = gA_test
                torch.save(
                    {
                        "model_state_dict": tbsm.state_dict(),
                        # "opt_state_dict": optimizer.state_dict(),
                    },
                    args.save_model,
                )
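    # in tuning/debug mode, report the single validation accuracy to the caller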
if not isMainTraining:
return gA_test
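
# A minimal usage sketch (hypothetical: the loader/model construction and
# args.nepochs are assumptions, not part of this file's shown API):
#
#   writer = SummaryWriter()       # from torch.utils.tensorboard
#   losses = np.empty((0, 3))      # rows of [iteration, train_loss, val_loss]
#   accuracies = np.empty((0, 3))  # rows of [iteration, train_acc, val_acc]
#   for k in range(args.nepochs):
#       iterate_train_data(args, train_ld, val_ld, tbsm, k, use_gpu, device,
#                          writer, losses, accuracies, isMainTraining=True)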