def train()

in long_term/pace_network.py
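
This method runs the full training loop for the pace network: Adam on an L1 loss, an exponentially decayed learning rate, a single fixed validation batch drawn up front, and a periodic wall-clock benchmark. It assumes the enclosing module imports `time`, `numpy as np`, `torch`, `torch.nn as nn`, and `torch.optim as optim`, and that the class provides `self.model`, `self.use_cuda`, and the `_prepare_next_batch` batch generator used below.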


    def train(self, dataset, sequences_train, sequences_valid, batch_size=40, chunk_length=1000, n_epochs=2000):
        np.random.seed(1234)  # fixed seed so batch sampling is reproducible
        self.model.train()

        lr = 0.001
        lr_decay = 0.999  # multiplicative decay applied once per epoch
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        criterion = nn.L1Loss()

        if len(sequences_valid) > 0:
            # Draw one fixed validation batch up front and reuse it at every epoch.
            inputs_valid, outputs_valid = next(self._prepare_next_batch(batch_size, chunk_length, dataset, sequences_valid))
            inputs_valid = torch.from_numpy(inputs_valid)
            outputs_valid = torch.from_numpy(outputs_valid)
            if self.use_cuda:
                inputs_valid = inputs_valid.cuda()
                outputs_valid = outputs_valid.cuda()
                
        losses = []
        valid_losses = []
        start_epoch = 0
        start_time = time.time()
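        # Main training loop: one full pass over the training chunks per epoch;
        # KeyboardInterrupt aborts cleanly and still returns the loss curves.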
        try:
            for epoch in range(n_epochs):
                batch_loss = 0.0
                N = 0
                for inputs, outputs in self._prepare_next_batch(batch_size, chunk_length, dataset, sequences_train):
                    inputs = torch.from_numpy(inputs)
                    outputs = torch.from_numpy(outputs)
                    if self.use_cuda:
                        inputs = inputs.cuda()
                        outputs = outputs.cuda()

                    # Standard update: zero gradients, forward pass, L1 loss, backprop, step.
                    optimizer.zero_grad()
                    predicted = self.model(inputs)
                    loss = criterion(predicted, outputs)
                    loss.backward()
                    optimizer.step()

                    # Weight by batch size so the epoch loss is a true per-sample average.
                    batch_loss += loss.item() * inputs.shape[0]
                    N += inputs.shape[0]

                batch_loss /= N
                losses.append(batch_loss)

                if len(sequences_valid) > 0:
                    with torch.no_grad():
                        # Evaluate on the held-out batch, not on the last training batch.
                        predicted = self.model(inputs_valid)
                        loss = criterion(predicted, outputs_valid)
                        valid_losses.append(loss.item())
                        print('[%d] loss %.6f valid %.6f' % (epoch + 1, losses[-1], valid_losses[-1]))
                else:
                    print('[%d] loss %.6f' % (epoch + 1, losses[-1]))

                # Decay the learning rate once per epoch.
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= lr_decay

                # Report the average wall-clock time per epoch every 10 epochs.
                if (epoch + 1) % 10 == 0:
                    time_per_epoch = (time.time() - start_time) / (epoch + 1 - start_epoch)
                    print('Benchmark:', time_per_epoch, 's per epoch')
                    start_epoch = epoch + 1
                    start_time = time.time()
        except KeyboardInterrupt:
            print('Training aborted.')

        print('Finished Training')
        return losses, valid_losses
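
A minimal usage sketch, assuming the enclosing class can be instantiated as shown; `PaceNetwork`, `dataset`, `train_seqs`, and `valid_seqs` are illustrative placeholders for whatever the surrounding project provides, not names confirmed by this excerpt:

    # Hypothetical driver code.
    net = PaceNetwork()  # assumed wrapper class exposing self.model, use_cuda, etc.
    losses, valid_losses = net.train(dataset, train_seqs, valid_seqs,
                                     batch_size=40, chunk_length=1000, n_epochs=2000)
    print('final training loss: %.6f' % losses[-1])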