void train()

in cpp/transfer-learning/main.cpp [123:189]


void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size) {
    /*
     This function trains the network on our data loader using optimizer.

     For every epoch it iterates over all batches of data_loader, runs the
     batch through net, flattens the result, feeds it to lin, computes the
     negative log-likelihood loss on the log-softmax of the output and performs
     one optimizer step. After each epoch it prints the epoch's accuracy and
     mean per-batch loss, calls test(), and - only when the epoch's training
     accuracy is strictly better than the best seen so far - saves net as
     "model.pt" and lin as "model_linear.pt".

     Parameters
     ===========
     1. net (torch::jit::script::Module type) - Pre-trained model without last FC layer
     2. lin (torch::nn::Linear type) - last FC layer with revised out_features depending on the number of classes
     3. data_loader (DataLoader& type) - Training data loader
     4. optimizer (torch::optim::Optimizer& type) - Optimizer like Adam, SGD etc.
     5. dataset_size (size_t type) - Size of training dataset, used as the denominator of the accuracy

     Returns
     ===========
     Nothing (void)
     */
    constexpr int kNumEpochs = 25;
    float best_accuracy = 0.0f;

    for (int epoch = 0; epoch < kNumEpochs; ++epoch) {
        float loss_sum = 0.0f;     // Sum of the per-batch losses of this epoch
        float num_correct = 0.0f;  // Number of correctly classified samples in this epoch
        // Counted per epoch: a counter shared across epochs would divide this
        // epoch's loss sum by the batches of all previous epochs as well.
        size_t num_batches = 0;

        for (auto& batch : *data_loader) {
            // Should be of length: batch_size
            auto data = batch.data.to(torch::kF32);
            auto target = batch.target.squeeze().to(torch::kInt64);

            std::vector<torch::jit::IValue> input;
            input.push_back(data);
            optimizer.zero_grad();

            auto output = net.forward(input).toTensor();
            // For transfer learning: flatten everything but the batch
            // dimension before applying the replacement FC layer.
            output = output.view({output.size(0), -1});
            output = lin(output);

            auto loss = torch::nll_loss(torch::log_softmax(output, 1), target);

            loss.backward();
            optimizer.step();

            auto correct = output.argmax(1).eq(target).sum();

            num_correct += correct.template item<float>();
            loss_sum += loss.template item<float>();

            ++num_batches;
        }

        // Take mean of loss over this epoch's batches; an empty loader yields 0
        // instead of dividing by zero.
        const float mean_loss = num_batches > 0 ? loss_sum / static_cast<float>(num_batches) : 0.0f;
        const float accuracy = dataset_size > 0 ? num_correct / static_cast<float>(dataset_size) : 0.0f;

        // The "MSE" label is kept for output compatibility; the value printed
        // is the mean NLL loss computed above.
        std::cout << "Epoch: " << epoch << ", " << "Accuracy: " << accuracy << ", " << "MSE: " << mean_loss << std::endl;

        // NOTE(review): test() is given the same loader that was just used for
        // training - confirm that this is intended.
        test(net, lin, data_loader, dataset_size);

        if (accuracy > best_accuracy) {
            best_accuracy = accuracy;
            std::cout << "Saving model" << std::endl;
            net.save("model.pt");
            torch::save(lin, "model_linear.pt");
        }
    }
}