def train_model()

in decoder.py


import timeit

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Encoder, PoincareDistance, and RiemannianSGD are assumed to be defined
# elsewhere in this project; their import paths are omitted here.


def train_model(coordinates_from, coordinates_to, file_name, method_type, n_epochs=1000,
	space='poincare', lr=1e-4, wd=1e-3, batch_size=8, n_warmup=3000, cuda=0, tb=0, bn='before', lrm=1.0):
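	"""Train an Encoder that maps coordinates_from to coordinates_to.

	Training runs in two phases: for the first n_warmup epochs the network is
	fitted with an MSE loss and Adam; afterwards (when space == 'poincare') it
	is fine-tuned with the Poincare distance as the loss and Riemannian SGD.
	method_type == 'decoder' forces a purely Euclidean fit. Learning curves are
	plotted every 100 epochs, the encoder's state_dict is saved to
	<file_name>_<method_type>.pth.tar, and the trained encoder (moved to CPU)
	is returned. Note that lrm is currently unused.
	"""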

	# A plain decoder is fitted in Euclidean space regardless of the requested space.
	if method_type == 'decoder':
		space = 'euclidean'

	if cuda:
		device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	else:
		device = torch.device("cpu")

	print(f"Computing on {device}")

	encoder = Encoder(n_inputs=coordinates_from.shape[1], n_outputs=coordinates_to.shape[1], bn=bn)
	encoder = encoder.to(device)
	
	optimizer = torch.optim.Adam(encoder.parameters(), lr=lr, weight_decay=wd)    

	loss = torch.nn.MSELoss()
	
	if tb:
		writer = SummaryWriter()

	X_train, X_test, y_train, y_test = train_test_split(coordinates_from, coordinates_to, test_size=0.3, random_state=42)
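	# Keep the full train/test tensors on the training device for the per-epoch
	# evaluation below; the DataLoader batches the same training tensors.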
	t_X_train = torch.Tensor(X_train).to(device)
	t_y_train = torch.Tensor(y_train).to(device)
	t_X_test = torch.Tensor(X_test).to(device)
	t_y_test = torch.Tensor(y_test).to(device)
	loader = DataLoader(TensorDataset(t_X_train, t_y_train),
						batch_size=batch_size,
						shuffle=True)
	
	# Poincare distance module used for the per-epoch evaluation below
	poincare_distances = PoincareDistance()

	train_error = []
	test_error = []

	pbar = tqdm(range(n_epochs), ncols=80)
	t_start = timeit.default_timer()

	n_iter = 0
	for epoch in pbar:
		epoch_error = 0

		# After the warmup epochs, switch the Poincare branch to Riemannian SGD
		# (note: its learning rate is hard-coded to 1e-4, independent of lr/lrm).
		if epoch == n_warmup:
			optimizerSGD = RiemannianSGD(encoder.parameters(), lr=1e-4)

		for inputs, targets in loader:
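			# MSE + Adam during the warmup epochs (and always in the Euclidean case);
			# Poincare distance + Riemannian SGD once the warmup is over.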
			preds = encoder(inputs)

			if epoch >= n_warmup and space == 'poincare':
				z = PoincareDistance()(preds, targets)
				error_encoder = torch.mean(z)
				optimizerSGD.zero_grad()
				error_encoder.backward()
				optimizerSGD.step()
			else:
				error_encoder = loss(preds, targets) 
				optimizer.zero_grad()
				error_encoder.backward()
				optimizer.step()

			epoch_error += error_encoder.item()

			if tb:
				writer.add_scalar("data/train/error", error_encoder.item(), n_iter)
				writer.add_histogram("data/train/predictions", preds.data, n_iter)
				writer.add_histogram("data/train/targets", targets.data, n_iter)

			n_iter += 1
		
		pbar.set_description("loss: {:.5e}".format(epoch_error))
		
		# Per-epoch evaluation on the full train/test splits (in eval mode)
		encoder.eval()
		if space == 'poincare':
			test_error.append(torch.mean(poincare_distances(encoder(t_X_test), t_y_test)).detach().cpu())
			train_error.append(torch.mean(poincare_distances(encoder(t_X_train), t_y_train)).detach().cpu())
		else:
			test_error.append(loss(encoder(t_X_test), t_y_test).detach().cpu())
			train_error.append(loss(encoder(t_X_train), t_y_train).detach().cpu())

		if epoch % 100 == 0:
			fig = plt.figure(figsize=(5, 5))
			plt.plot(np.log10(train_error), label='train', color='red')
			plt.plot(np.log10(test_error), label='test', color='green')
			plt.legend()
			# Save before show(): show() may clear the figure in some backends.
			plt.savefig(file_name + '_training_error.png', format='png', dpi=150)
			plt.show()
			plt.close(fig)

		if tb:
			writer.add_scalar("data/test/epoch_error", test_error[-1], epoch)
			writer.add_scalar("data/train/epoch_error", train_error[-1], epoch)
			writer.add_histogram("data/test/predictions", encoder(t_X_test).detach().cpu(), epoch)
			writer.add_histogram("data/test/targets", t_y_test, epoch)

		encoder.train()

	encoder.eval()

	print(f"epoch_error = {epoch_error:.5e}")
	elapsed = timeit.default_timer() - t_start
	print(f"Time: {elapsed:.2f}")

	print(f"Max norm = {torch.max(torch.sum(encoder(t_X_test)**2, dim=1)):.2f}")
	
	encoder = encoder.to("cpu")
	torch.save(encoder.state_dict(), f"{file_name}_{method_type}.pth.tar")
	
	if tb:
		writer.close()

	return encoder
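
A minimal usage sketch (not part of decoder.py), assuming coordinates_from and coordinates_to are NumPy arrays with one row per point, e.g. source coordinates and target Poincare-ball coordinates produced elsewhere; the file names and the "encoder" tag below are hypothetical:

import numpy as np

# Hypothetical input files; any arrays of shape (n_points, d_in) / (n_points, d_out) work.
coordinates_from = np.load("source_coords.npy")
coordinates_to = np.load("poincare_coords.npy")

# Any method_type other than 'decoder' keeps the requested space ('poincare' here).
encoder = train_model(coordinates_from, coordinates_to,
                      file_name="example_run", method_type="encoder",
                      n_epochs=500, n_warmup=100, space='poincare', cuda=0)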