def train_model()

in ml3/mbrl_utils.py [0:0]


    def train_model(self, training_data):
        train_loader = torch.utils.data.DataLoader(
            training_data, batch_size=64, num_workers=0
        )
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        loss_fn = NLLLoss()  # negative log-likelihood training loss
        for epoch in range(self.epochs):
            losses = []
            for data, target in train_loader:
                # Cast inputs and targets to float tensors and move them to the model's device.
                x = data.float().to(device=self.device)
                y = target.float().to(device=self.device)

                # Promote 1-D batches to column vectors of shape (N, 1).
                if x.dim() == 1:
                    x = x.unsqueeze(0).t()
                if y.dim() == 1:
                    y = y.unsqueeze(0).t()

                # Forward pass, loss, and one optimizer step per batch.
                py = self.forward(x)
                loss = loss_fn(py, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())

            # Print the mean batch loss every `display_epoch` epochs.
            if epoch % self.display_epoch == 0:
                print(
                    colored(
                        "epoch={}, loss={}".format(epoch, np.mean(losses)), "yellow"
                    )
                )
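
Usage sketch (illustrative only): the snippet below assumes `model` is an already-constructed instance of the class in ml3/mbrl_utils.py that defines `train_model` together with the `learning_rate`, `epochs`, `display_epoch`, and `device` attributes the method reads. The tensor shapes and the state/action interpretation are assumptions for this sketch, not taken from the repository.


    import torch
    from torch.utils.data import TensorDataset

    # Illustrative data: inputs might be concatenated states and actions,
    # targets the next states; both shapes are assumptions for this sketch.
    inputs = torch.randn(256, 4)
    targets = torch.randn(256, 4)
    training_data = TensorDataset(inputs, targets)

    # `model` is assumed to be an instance of the class that defines
    # train_model() above.
    model.train_model(training_data)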