// cpp/transfer-learning/main.cpp [123:189]
void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size) {
/*
This function trains the network on our data loader using optimizer.
Also saves the model as model.pt after every epoch.
Parameters
===========
1. net (torch::jit::script::Module type) - Pre-trained model without last FC layer
2. lin (torch::nn::Linear type) - last FC layer with revised out_features depending on the number of classes
3. data_loader (DataLoader& type) - Training data loader
4. optimizer (torch::optim::Optimizer& type) - Optimizer like Adam, SGD etc.
5. size_t (dataset_size type) - Size of training dataset
Returns
===========
Nothing (void)
*/
float best_accuracy = 0.0;
int batch_index = 0;
for(int i=0; i<25; i++) {
float mse = 0;
float Acc = 0.0;
for(auto& batch: *data_loader) {
auto data = batch.data;
auto target = batch.target.squeeze();
// Should be of length: batch_size
data = data.to(torch::kF32);
target = target.to(torch::kInt64);
std::vector<torch::jit::IValue> input;
input.push_back(data);
optimizer.zero_grad();
auto output = net.forward(input).toTensor();
// For transfer learning
output = output.view({output.size(0), -1});
output = lin(output);
auto loss = torch::nll_loss(torch::log_softmax(output, 1), target);
loss.backward();
optimizer.step();
auto acc = output.argmax(1).eq(target).sum();
Acc += acc.template item<float>();
mse += loss.template item<float>();
batch_index += 1;
}
mse = mse/float(batch_index); // Take mean of loss
std::cout << "Epoch: " << i << ", " << "Accuracy: " << Acc/dataset_size << ", " << "MSE: " << mse << std::endl;
test(net, lin, data_loader, dataset_size);
if(Acc/dataset_size > best_accuracy) {
best_accuracy = Acc/dataset_size;
std::cout << "Saving model" << std::endl;
net.save("model.pt");
torch::save(lin, "model_linear.pt");
}
}
}