int main()

in cpp/custom-dataset/custom-dataset.cpp [189:227]


int main() {
  torch::manual_seed(1);

  if (torch::cuda::is_available())
    options.device = torch::kCUDA;
  std::cout << "Running on: "
            << (options.device == torch::kCUDA ? "CUDA" : "CPU") << std::endl;

  auto data = readInfo();

  auto train_set =
      CustomDataset(data.first).map(torch::data::transforms::Stack<>());
  auto train_size = train_set.size().value();
  auto train_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          std::move(train_set), options.train_batch_size);

  auto test_set =
      CustomDataset(data.second).map(torch::data::transforms::Stack<>());
  auto test_size = test_set.size().value();
  auto test_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          std::move(test_set), options.test_batch_size);

  Network network;
  network->to(options.device);

  torch::optim::SGD optimizer(
      network->parameters(), torch::optim::SGDOptions(0.001).momentum(0.5));

  for (size_t i = 0; i < options.iterations; ++i) {
    train(network, *train_loader, optimizer, i + 1, train_size);
    std::cout << std::endl;
    test(network, *test_loader, test_size);
    std::cout << std::endl;
  }

  return 0;
}