// Excerpt: py::tuple ValidOrdersEncoder::encode_valid_orders(const std::string &power_s, GameState &state)
//
// from dipcc/dipcc/pybind/valid_orders_encoder.h, lines [67:147]


// Encode the orders that `power_s` may issue in `state` as fixed-size numpy
// arrays for the Python side.
//
// Returns a 3-tuple (all_order_idxs, loc_idxs, n):
//   all_order_idxs: int32 array of shape (1, MAX_SEQ_LEN, max_cands_). Each
//     filled row holds sorted order-vocabulary indices of candidate orders;
//     every unused slot is EOS_IDX.
//       - builds:    row 0 holds every order from get_compound_build_orders().
//       - disbands:  each of the first n_disbands rows holds the same list of
//                    all in-vocab orders at any orderable location.
//       - move/retreat: row i holds the in-vocab orders for orderable_locs[i].
//   loc_idxs: int8 array of shape (1, 81), indexed by (root loc id - 1).
//     -1 = not orderable. Builds/disbands: every orderable location is -2.
//     Move/retreat: the row index i used for that location.
//   n: n_builds (clamped to the number of orderable locations), n_disbands,
//     or the number of orderable locations, depending on the phase; 0 when the
//     power has nothing to order.
//
// NOTE(review): nothing here clamps to MAX_SEQ_LEN / max_cands_; this relies on
// pybind11's bounds-checked indexed mutable_data() to reject a phase that would
// overflow the arrays -- confirm that is the intended failure mode.
py::tuple ValidOrdersEncoder::encode_valid_orders(const std::string &power_s,
                                                  GameState &state) {
  Power power(power_from_str(power_s));

  // Init return value: all_order_idxs, every slot EOS_IDX.
  // Filled element-wise: memset writes bytes, so it would only produce EOS_IDX
  // when the sentinel's int32 byte pattern is uniform (e.g. 0 or -1).
  py::array_t<int32_t> all_order_idxs({1, MAX_SEQ_LEN, max_cands_});
  std::fill_n(all_order_idxs.mutable_data(), all_order_idxs.size(),
              static_cast<int32_t>(EOS_IDX));

  // Init return value: loc_idxs, every location "not orderable" (-1).
  py::array_t<int8_t> loc_idxs({1, 81});
  std::fill_n(loc_idxs.mutable_data(), loc_idxs.size(),
              static_cast<int8_t>(-1));

  // Early exit: this power has no orderable locations.
  auto orderable_locs_it = state.get_orderable_locations().find(power);
  if (orderable_locs_it == state.get_orderable_locations().end() ||
      orderable_locs_it->second.size() == 0) {
    return py::make_tuple(all_order_idxs, loc_idxs, 0);
  }

  // Get orderable_locs sorted by coast-specific loc idx (orderable_locs returns
  // root_locs)
  auto &all_possible_orders(state.get_all_possible_orders());
  std::vector<Loc> orderable_locs(get_sorted_actual_orderable_locs(
      orderable_locs_it->second, all_possible_orders));

  // Build and disband phases flag every orderable location with -2 instead of
  // a per-row index.
  auto mark_all_orderable = [&]() {
    for (Loc loc : orderable_locs) {
      *loc_idxs.mutable_data(0, static_cast<int>(root_loc(loc)) - 1) = -2;
    }
  };

  int n_builds = state.get_n_builds(power);
  if (n_builds > 0) {
    // Builds phase: cannot build more units than there are orderable locations.
    n_builds = std::min(n_builds, static_cast<int>(orderable_locs.size()));
    const std::vector<std::string> orders(get_compound_build_orders(
        all_possible_orders, orderable_locs, n_builds));

    std::vector<int> order_idxs;
    order_idxs.reserve(orders.size());
    for (const std::string &order : orders) {
      order_idxs.push_back(order_vocabulary_to_idx_.at(order));
    }
    std::sort(order_idxs.begin(), order_idxs.end());

    // All compound build orders go into row 0.
    for (size_t j = 0; j < order_idxs.size(); ++j) {
      *all_order_idxs.mutable_data(0, 0, j) = order_idxs[j];
    }
    mark_all_orderable();
    return py::make_tuple(all_order_idxs, loc_idxs, n_builds);

  } else if (n_builds < 0) {
    // Disband phase: pool the in-vocab orders of every orderable location.
    const int n_disbands = -n_builds;
    std::vector<int> order_idxs;
    order_idxs.reserve(orderable_locs.size());
    for (Loc loc : orderable_locs) {
      for (int idx : filter_orders_in_vocab(all_possible_orders.at(loc))) {
        order_idxs.push_back(idx);
      }
    }
    std::sort(order_idxs.begin(), order_idxs.end());

    // Every disband row offers the same pooled candidate list.
    for (int i = 0; i < n_disbands; ++i) {
      for (size_t j = 0; j < order_idxs.size(); ++j) {
        *all_order_idxs.mutable_data(0, i, j) = order_idxs[j];
      }
    }
    mark_all_orderable();
    return py::make_tuple(all_order_idxs, loc_idxs, n_disbands);

  } else {
    // Move or retreat phase: one row per orderable location.
    for (size_t i = 0; i < orderable_locs.size(); ++i) {
      const Loc loc = orderable_locs[i];
      std::vector<int> order_idxs(
          filter_orders_in_vocab(all_possible_orders.at(loc)));
      std::sort(order_idxs.begin(), order_idxs.end());
      for (size_t j = 0; j < order_idxs.size(); ++j) {
        *all_order_idxs.mutable_data(0, i, j) = order_idxs[j];
      }
      // Map the location to its row once (not once per candidate). A location
      // with no in-vocab orders stays -1 here even though its all-EOS row is
      // still counted in the returned n.
      // NOTE(review): confirm that unmapped-but-counted case is intended.
      if (!order_idxs.empty()) {
        *loc_idxs.mutable_data(0, static_cast<int>(root_loc(loc)) - 1) =
            static_cast<int8_t>(i);
      }
    }
    return py::make_tuple(all_order_idxs, loc_idxs,
                          static_cast<int>(orderable_locs.size()));
  }
} // encode_valid_orders