int main()

in cpp/distributed/dist-mnist.cpp [50:185]

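The main routine of the MPI-based distributed MNIST example: it creates a c10d MPI process group, shards the training data across ranks with a DistributedRandomSampler, averages gradients with an allreduce after every backward pass, and evaluates the trained model on rank 0 only.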

int main(int argc, char* argv[]) {
  // Create the MPI process group. createProcessGroupMPI() initializes MPI if
  // it has not been initialized yet, so every rank must reach this call.
  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();

  // Query the world size and this process's rank from the process group
  auto numranks = pg->getSize();
  auto rank = pg->getRank();

  // TRAINING
  // Read train dataset
  const char* kDataRoot = "../data";
  auto train_dataset =
      torch::data::datasets::MNIST(kDataRoot)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());

  // Distributed Random Sampler
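  // Each rank draws a distinct shard of the dataset. The trailing `false` is
  // allow_duplicates: the local shard size is rounded down, so no sample is
  // duplicated (a few may be dropped if numranks does not divide the size).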
  auto data_sampler = torch::data::samplers::DistributedRandomSampler(
      train_dataset.size().value(), numranks, rank, false);

  auto num_train_samples_per_proc = train_dataset.size().value() / numranks;

  // Generate dataloader
  auto total_batch_size = 64;
  auto batch_size_per_proc =
      total_batch_size / numranks; // per-process batch size (integer
                                   // division, so numranks should divide 64)
  auto data_loader = torch::data::make_data_loader(
      std::move(train_dataset), data_sampler, batch_size_per_proc);

  // Set the manual seed so every rank initializes identical model weights
  torch::manual_seed(0);

  auto model = std::make_shared<Model>();

  auto learning_rate = 1e-2;

  torch::optim::SGD optimizer(model->parameters(), learning_rate);

  // Number of epochs
  size_t num_epochs = 10;

  for (size_t epoch = 1; epoch <= num_epochs; ++epoch) {
    size_t num_correct = 0;

    for (auto& batch : *data_loader) {
      auto ip = batch.data;
      auto op = batch.target.squeeze();

      // convert to required formats
      ip = ip.to(torch::kF32);
      op = op.to(torch::kLong);

      // Reset gradients
      model->zero_grad();

      // Execute forward pass
      auto prediction = model->forward(ip);

      auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op);

      // Backpropagation
      loss.backward();

      // Average the gradients of the parameters across all processes.
      // Note: this may lag behind DistributedDataParallel (DDP) in
      // performance, since it synchronizes gradients only after the backward
      // pass has finished, while DDP overlaps gradient communication with
      // the backward computation itself (see the sketch after this listing).
      std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
      for (auto& param : model->named_parameters()) {
        std::vector<torch::Tensor> tmp = {param.value().grad()};
        auto work = pg->allreduce(tmp);
        works.push_back(std::move(work));
      }

      // Block until all outstanding allreduce operations have completed.
      // waitWork() is a helper defined earlier in dist-mnist.cpp.
      waitWork(pg, works);

      // allreduce defaults to summing; divide by the rank count to average
      for (auto& param : model->named_parameters()) {
        param.value().grad().data() = param.value().grad().data() / numranks;
      }

      // Update parameters
      optimizer.step();

      auto guess = prediction.argmax(1);
      num_correct += torch::sum(guess.eq(op)).item<int64_t>();
    } // end batch loader

    auto accuracy = 100.0 * num_correct / num_train_samples_per_proc;

    std::cout << "Accuracy in rank " << rank << " in epoch " << epoch << " - "
              << accuracy << std::endl;

  } // end epoch

  // TESTING ONLY IN RANK 0
  // Every rank seeded identically and applied the same averaged gradients,
  // so the model replicas are identical and evaluating on rank 0 suffices.
  if (rank == 0) {
    auto test_dataset =
        torch::data::datasets::MNIST(
            kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
            .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
            .map(torch::data::transforms::Stack<>());

    auto num_test_samples = test_dataset.size().value();
    // Load the entire test set as a single batch
    auto test_loader = torch::data::make_data_loader(
        std::move(test_dataset), num_test_samples);

    model->eval(); // switch layers such as dropout to inference behavior
    torch::NoGradGuard no_grad; // gradient tracking is not needed for testing

    size_t num_correct = 0;

    for (auto& batch : *test_loader) {
      auto ip = batch.data;
      auto op = batch.target.squeeze();

      // convert to required format
      ip = ip.to(torch::kF32);
      op = op.to(torch::kLong);

      auto prediction = model->forward(ip);

      auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op);

      std::cout << "Test loss - " << loss.item<float>() << std::endl;

      auto guess = prediction.argmax(1);

      num_correct += torch::sum(guess.eq(op)).item<int64_t>();

    } // end test loader

    std::cout << "Num correct - " << num_correct << std::endl;
    std::cout << "Test Accuracy - " << 100.0 * num_correct / num_test_samples
              << std::endl;
  } // end rank 0
}
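
The synchronization step above launches one allreduce per parameter, and only after loss.backward() has returned. One of the optimizations DDP applies is bucketing: fewer, larger collectives instead of one per tensor. The sketch below is a hypothetical single-bucket variant of the same step, not code from dist-mnist.cpp; it reuses `model`, `pg`, and `numranks` from the listing and assumes all parameter gradients are contiguous. DDP additionally overlaps communication with the backward pass via autograd hooks, which this sketch does not attempt.

// Hypothetical single-bucket variant of the per-parameter synchronization:
// flatten all gradients, run one allreduce, average, and scatter back.
std::vector<torch::Tensor> grads;
for (auto& p : model->parameters()) {
  grads.push_back(p.grad().view(-1)); // assumes contiguous gradients
}
auto flat = torch::cat(grads); // one contiguous communication buffer
std::vector<torch::Tensor> bucket = {flat};
pg->allreduce(bucket)->wait(); // a single collective per step
flat.div_(numranks);           // turn the sum into an average
int64_t offset = 0;
for (auto& p : model->parameters()) {
  auto n = p.numel();
  p.grad().view(-1).copy_(flat.narrow(0, offset, n));
  offset += n;
}

To try the example, build it against a LibTorch compiled with MPI support and launch one process per rank, for instance (assuming the binary is named dist-mnist):

mpirun -np 4 ./dist-mnist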