def fit()

in src/torch_wrapper.py [0:0]


    def fit(self, X, y, sampleweights, n_epochs, loss_type='BCE'):
        """
        Fit the model by full-batch gradient descent (the whole sample is one batch).

        Each epoch does one forward pass over all of ``X``, one backward pass
        and one ``self.optimizer.step()``.

        Args:
            X: Feature matrix (NumPy array or array-like). No dtype conversion
                is applied, so it should match the model's parameter dtype.
            y: Binary targets, one per sample; converted to float64.
            sampleweights: Per-sample weights, multiplied element-wise into the
                BCE loss before the (default) mean reduction.
            n_epochs: Number of full-batch optimisation steps. ``0`` performs
                no update (the model is still switched to training mode).
            loss_type: Currently ignored; binary cross-entropy with logits is
                always used. Kept for interface compatibility.

        Returns:
            self, to allow call chaining.
        """
        X = torch.as_tensor(X)
        y = torch.as_tensor(y).double()
        weights = torch.as_tensor(sampleweights)
        self.model.train()  # training mode (affects layers such as dropout / batch-norm)

        # Binary cross-entropy on raw logits, weighted per sample.
        criterion = nn.BCEWithLogitsLoss(weight=weights)

        for _ in range(n_epochs):
            self.optimizer.zero_grad()  # clear gradients left over from the previous step
            logits = self.model(X)
            # squeeze(-1) drops only the trailing output dimension. A bare
            # squeeze() would also drop the batch dimension when there is a
            # single sample, giving a 0-d tensor that BCEWithLogitsLoss
            # rejects against the (1,)-shaped target.
            loss = criterion(logits.squeeze(-1), y)
            loss.backward()
            self.optimizer.step()

        return self