# decoder.py
import timeit

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Encoder, PoincareDistance, and RiemannianSGD are project-specific; they are assumed to be
# defined elsewhere in this repository (their exact import paths are not shown here).
def train_model(coordinates_from, coordinates_to, file_name, method_type, n_epochs=1000,
                space='poincare', lr=1e-4, wd=1e-3, batch_size=8, n_warmup=3000, cuda=0,
                tb=0, bn='before', lrm=1.0):
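    """Train an Encoder mapping `coordinates_from` to `coordinates_to`.

    Training warms up with Adam on an MSE loss; after `n_warmup` epochs, and only when
    `space == 'poincare'`, optimization switches to RiemannianSGD on the mean Poincare
    distance between predictions and targets.

    Parameter descriptions below are inferred from the code in this function:
        coordinates_from, coordinates_to: arrays of shape (n_samples, n_features).
        file_name: prefix for the saved model and the training-error plot.
        method_type: 'decoder' forces the Euclidean loss regardless of `space`.
        n_epochs, lr, wd, batch_size: usual training hyperparameters.
        n_warmup: number of warm-up epochs trained with Adam + MSE.
        cuda, tb: flags enabling GPU training and TensorBoard logging.
        bn: where the Encoder applies batch norm ('before' by default).
        lrm: currently unused.

    Returns the trained encoder, moved back to CPU.
    """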
    # A plain decoder is trained in Euclidean space regardless of the requested geometry.
    if method_type == 'decoder':
        space = 'euclidean'

    # Select the device; fall back to CPU when CUDA is requested but unavailable.
    if cuda:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print(f"Computing on {device}")

    encoder = Encoder(n_inputs=coordinates_from.shape[1], n_outputs=coordinates_to.shape[1], bn=bn)
    encoder = encoder.to(device)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=wd)
    loss = torch.nn.MSELoss()

    if tb:
        writer = SummaryWriter()
    X_train, X_test, y_train, y_test = train_test_split(
        coordinates_from, coordinates_to, test_size=0.3, random_state=42)

    # Keep the full train/test tensors on the selected device for the per-epoch evaluation,
    # and reuse the training tensors for the mini-batch loader.
    t_X_train = torch.Tensor(X_train).to(device)
    t_y_train = torch.Tensor(y_train).to(device)
    t_X_test = torch.Tensor(X_test).to(device)
    t_y_test = torch.Tensor(y_test).to(device)
    loader = DataLoader(TensorDataset(t_X_train, t_y_train),
                        batch_size=batch_size,
                        # pin_memory=True,
                        shuffle=True)
    poincare_distances = PoincareDistance()
    train_error = []
    test_error = []
    pbar = tqdm(range(n_epochs), ncols=80)
    t_start = timeit.default_timer()
    n_iter = 0
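    # Training scheme: warm up with Adam on the MSE loss, then (for the Poincare space)
    # switch to RiemannianSGD minimizing the mean Poincare distance between predictions
    # and targets. Note that with the default arguments n_warmup (3000) exceeds
    # n_epochs (1000), so the switch only happens when n_warmup is lowered.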
    for epoch in pbar:
        epoch_error = 0
        # if epoch == 100:
        #     optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=wd)
        if epoch == n_warmup:
            # After warm-up, switch to Riemannian SGD for the Poincare-distance objective.
            optimizerSGD = RiemannianSGD(encoder.parameters(), lr=1e-4)
        for inputs, targets in loader:
            preds = encoder(inputs)
            if epoch >= n_warmup and space == 'poincare':
                # Riemannian phase: minimize the mean Poincare distance.
                error_encoder = torch.mean(poincare_distances(preds, targets))
                optimizerSGD.zero_grad()
                error_encoder.backward()
                optimizerSGD.step()
            else:
                # Warm-up (or Euclidean) phase: minimize the MSE loss with Adam.
                error_encoder = loss(preds, targets)
                optimizer.zero_grad()
                error_encoder.backward()
                optimizer.step()
            epoch_error += error_encoder.item()
            if tb:
                writer.add_scalar("data/train/error", error_encoder.item(), n_iter)
                writer.add_histogram("data/train/predictions", preds.data, n_iter)
                writer.add_histogram("data/train/targets", targets.data, n_iter)
            n_iter += 1
        pbar.set_description("loss: {:.5e}".format(epoch_error))
        # Per-epoch evaluation on the full train and test splits.
        encoder.eval()
        if space == 'poincare':
            test_error.append(torch.mean(poincare_distances(encoder(t_X_test), t_y_test)).detach().cpu())
            train_error.append(torch.mean(poincare_distances(encoder(t_X_train), t_y_train)).detach().cpu())
        else:
            test_error.append(loss(encoder(t_X_test), t_y_test).detach().cpu())
            train_error.append(loss(encoder(t_X_train), t_y_train).detach().cpu())
        if epoch % 100 == 0:
            # Plot the learning curves; save the figure before showing it so the file is not blank.
            fig = plt.figure(figsize=(5, 5))
            plt.plot(np.log10(train_error), label='train', color='red')
            plt.plot(np.log10(test_error), label='test', color='green')
            plt.legend()
            plt.savefig(file_name + '_training_error.png', format='png', dpi=150)
            plt.show()
            plt.close(fig)
        if tb:
            writer.add_scalar("data/test/epoch_error", test_error[-1], epoch)
            writer.add_scalar("data/train/epoch_error", train_error[-1], epoch)
            writer.add_histogram("data/test/predictions", encoder(t_X_test).detach().cpu(), epoch)
            writer.add_histogram("data/test/targets", t_y_test, epoch)
        encoder.train()
    encoder.eval()
    print(f"epoch_error = {epoch_error:.5e}")
    elapsed = timeit.default_timer() - t_start
    print(f"Time: {elapsed:.2f}")
    # Report the largest squared norm of the test predictions (points should stay inside
    # the unit ball when working in the Poincare space).
    print(f"Max norm = {torch.max(torch.sum(encoder(t_X_test) ** 2, dim=1)):.2f}")
    encoder = encoder.to("cpu")
    torch.save(encoder.state_dict(), f"{file_name}_{method_type}.pth.tar")
    if tb:
        writer.close()
    return encoder
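
# A minimal usage sketch (illustration only): the arrays and file prefix below are
# hypothetical, and method_type="decoder" forces the Euclidean MSE objective.
#
#     src = np.random.randn(500, 50).astype(np.float32)  # source coordinates
#     dst = np.random.randn(500, 2).astype(np.float32)   # target coordinates
#     model = train_model(src, dst, file_name="example", method_type="decoder",
#                         n_epochs=200, n_warmup=100)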