in ml3/mbrl_utils.py [0:0]
def train_model(self, training_data, *, batch_size=64, shuffle=False):
    """Fit this model to ``training_data`` using Adam and ``NLLLoss``.

    Runs ``self.epochs`` passes over ``training_data`` in mini-batches of
    ``batch_size``. Each batch's inputs and targets are cast to float32 and
    moved to ``self.device``. 1-D tensors are reshaped into ``(N, 1)`` columns
    before ``self.forward`` is called. Every ``self.display_epoch`` epochs
    (starting at epoch 0), the mean batch loss of that epoch is printed in
    yellow.

    Reads ``self.learning_rate``, ``self.epochs``, ``self.display_epoch`` and
    ``self.device``. The optimizer updates ``self.parameters()``; presumably
    ``self`` is a ``torch.nn.Module`` (TODO confirm against the class
    definition).

    Args:
        training_data: A dataset accepted by ``torch.utils.data.DataLoader``
            whose items unpack into ``(input, target)`` pairs.
        batch_size: Mini-batch size. Defaults to 64, the previously
            hard-coded value.
        shuffle: Whether the loader reshuffles samples every epoch. Defaults
            to ``False``, preserving the original sequential order.

    Returns:
        None. The model's parameters are updated in place.
    """
    train_loader = torch.utils.data.DataLoader(
        training_data, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    # NOTE(review): NLLLoss is resolved at file level (not visible here). It is
    # fed float targets, so it is probably a project-specific loss rather than
    # torch.nn.NLLLoss, which expects integer class indices. Confirm.
    loss_fn = NLLLoss()
    for epoch in range(self.epochs):
        losses = []
        for data, target in train_loader:
            # One conversion + transfer; avoids a detour through a CPU
            # FloatTensor when the batch already lives on another device.
            x = data.to(device=self.device, dtype=torch.float32)
            y = target.to(device=self.device, dtype=torch.float32)
            # Promote 1-D batches to (N, 1) column tensors.
            if x.dim() == 1:
                x = x.unsqueeze(1)
            if y.dim() == 1:
                y = y.unsqueeze(1)
            py = self.forward(x)
            loss = loss_fn(py, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        if epoch % self.display_epoch == 0:
            # An empty loader leaves `losses` empty; report nan explicitly
            # instead of letting np.mean([]) raise a RuntimeWarning.
            mean_loss = np.mean(losses) if losses else float("nan")
            print(colored("epoch={}, loss={}".format(epoch, mean_loss), "yellow"))