in tbsm_pytorch.py [0:0]
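# Module-level imports this function relies on -- a sketch, not the file's
# actual import list (the real file defines more); the alias `tp` is assumed
# to be the TBSM data module (tbsm_data_pytorch.py in the repo layout).
import numpy as np
import torch
from os import path
from torch.utils.tensorboard import SummaryWriter
import tbsm_data_pytorch as tp  # assumed source of make_tbsm_data_and_loader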
def train_tbsm(args, use_gpu):
    # prepare the data
    train_ld, _ = tp.make_tbsm_data_and_loader(args, "train")
    val_ld, _ = tp.make_tbsm_data_and_loader(args, "val")

    # setup initial values
    isMainTraining = False
    writer = SummaryWriter()
    losses = np.empty((0, 3), np.float32)
    accuracies = np.empty((0, 3), np.float32)

    # Select the best seed out of 5 candidates. Adagrad sometimes gets stuck
    # early; this appears to occur randomly, depending on the initial weight
    # values, and is independent of the chosen model (N-inner, dot, etc.).
    # This procedure reduces the probability of that happening.
    def select(args):
        seeds = np.random.randint(2, 10000, size=5)
        if args.debug_mode:
            print(seeds)
        best_index = 0
        max_val_accuracy = 0.0
        testpoint = min(int(0.05 * len(train_ld)), len(train_ld) - 1)
        print("testpoint, total batches: ", testpoint, len(train_ld))
        for i, seed in enumerate(seeds):
            set_seed(seed, use_gpu)
            tbsm, device = get_tbsm(args, use_gpu)
            # validation accuracy achieved with this candidate seed
            gA_test = iterate_train_data(args, train_ld, val_ld, tbsm, 0, use_gpu,
                                         device, writer, losses, accuracies,
                                         isMainTraining)
            if args.debug_mode:
                print("select: ", i, seed, gA_test, max_val_accuracy)
            if gA_test > max_val_accuracy:
                best_index = i
                max_val_accuracy = gA_test
        return seeds[best_index]

    # select best seed if needed
    if args.no_select_seed or path.exists(args.save_model):
        seed = args.numpy_rand_seed
    else:
        print("Choosing best seed...")
        seed = select(args)
    set_seed(seed, use_gpu)
    print("selected seed:", seed)

    # create or load TBSM
    tbsm, device = get_tbsm(args, use_gpu)
    if args.debug_mode:
        print("initial parameters (weights and bias):")
        for name, param in tbsm.named_parameters():
            print(name)
            print(param.detach().cpu().numpy())

    # main training loop
    isMainTraining = True
    print("time/loss/accuracy (if enabled):")
    with torch.autograd.profiler.profile(args.enable_profiling, use_cuda=use_gpu) as prof:
        for k in range(args.nepochs):
            iterate_train_data(args, train_ld, val_ld, tbsm, k, use_gpu, device,
                               writer, losses, accuracies, isMainTraining)

    # collect metrics and other statistics about the run
    if args.enable_summary:
        with open('summary.npy', 'wb') as acc_loss:
            np.save(acc_loss, losses)
            np.save(acc_loss, accuracies)
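            # both arrays share one file; read them back with two sequential
            # np.load calls on one handle, in save order:
            #   with open('summary.npy', 'rb') as f:
            #       losses, accuracies = np.load(f), np.load(f)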
    writer.close()

    # debug prints
    if args.debug_mode:
        print("final parameters (weights and bias):")
        for name, param in tbsm.named_parameters():
            print(name)
            print(param.detach().cpu().numpy())

    # profiling
    if args.enable_profiling:
        with open("tbsm_pytorch.prof", "w") as prof_f:
            prof_f.write(
                prof.key_averages(group_by_input_shape=True).table(
                    sort_by="self_cpu_time_total"
                )
            )
        prof.export_chrome_trace("./tbsm_pytorch.json")
    return
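

# For reference, a plausible sketch of the `set_seed` helper called above.
# Its body is an assumption (only the signature is visible in this listing);
# the idea is to seed NumPy and PyTorch so repeated runs are reproducible.
def set_seed(seed, use_gpu):
    np.random.seed(seed)     # NumPy RNG: candidate-seed draws, data shuffles
    torch.manual_seed(seed)  # CPU-side PyTorch RNG: weight initialization
    if use_gpu:
        torch.cuda.manual_seed_all(seed)            # all visible GPUs
        torch.backends.cudnn.deterministic = True   # prefer determinism over speed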