py::dict BAHelpers::BundleShotPoses()

in opensfm/src/sfm/src/ba_helpers.cc [383:556]


// Runs a bundle adjustment over the rig instances that the shots in
// `shot_ids` belong to and writes the resulting rig instance poses back
// into `map`.
//
// Cameras, rig cameras and landmarks are all added to the adjuster as fixed
// (see the `fix_*` constants below); only rig instance poses are read back
// after the run.
//
// Parameters:
//   map               - map holding shots, rig instances, rig cameras and
//                       landmarks; rig instance poses are updated in place.
//   shot_ids          - shots whose rig instances, cameras, landmarks and
//                       observations are fed to the adjuster.
//   camera_priors     - prior for every camera used by `shot_ids`; looked up
//                       with `.at()`, so a missing id throws
//                       std::out_of_range.
//   rig_camera_priors - prior for every rig camera of the involved rig
//                       instances; also looked up with `.at()`.
//   config            - reads "bundle_analytic_derivatives",
//                       "bundle_use_gps", "loss_function",
//                       "loss_function_threshold", "exif_focal_sd",
//                       "principal_point_sd", "radial_distortion_k{1..4}_sd",
//                       "tangential_distortion_p{1,2}_sd",
//                       "rig_translation_sd", "rig_rotation_sd" and
//                       "processes".
//
// Returns a dict with "brief_report" (the adjuster's brief report) and
// "wall_times", a dict with the "setup", "run" and "teardown" durations in
// seconds.
py::dict BAHelpers::BundleShotPoses(
    map::Map& map, const std::unordered_set<map::ShotId>& shot_ids,
    const std::unordered_map<map::CameraId, geometry::Camera>& camera_priors,
    const std::unordered_map<map::RigCameraId, map::RigCamera>&
        rig_camera_priors,
    const py::dict& config) {
  using Clock = std::chrono::high_resolution_clock;

  py::dict report;

  constexpr auto fix_cameras = true;
  constexpr auto fix_points = true;
  constexpr auto fix_rig_camera = true;

  auto ba = bundle::BundleAdjuster();
  ba.SetUseAnalyticDerivatives(
      config["bundle_analytic_derivatives"].cast<bool>());
  const auto start = Clock::now();

  // gather required rig data to setup
  std::unordered_set<map::RigInstanceId> rig_instances_ids;
  for (const auto& shot_id : shot_ids) {
    const auto& shot = map.GetShot(shot_id);
    rig_instances_ids.insert(shot.GetRigInstanceId());
  }
  std::unordered_set<map::RigCameraId> rig_cameras_ids;
  for (const auto& rig_instance_id : rig_instances_ids) {
    auto& instance = map.GetRigInstance(rig_instance_id);
    for (const auto& shot_n_rig_camera : instance.GetRigCameras()) {
      rig_cameras_ids.insert(shot_n_rig_camera.second->id);
    }
  }

  // rig cameras are going to be fixed
  for (const auto& rig_camera_id : rig_cameras_ids) {
    const auto& rig_camera = map.GetRigCamera(rig_camera_id);
    ba.AddRigCamera(rig_camera_id, rig_camera.pose,
                    rig_camera_priors.at(rig_camera_id).pose, fix_rig_camera);
  }

  // add each camera once, even when several shots share it
  std::unordered_set<map::CameraId> added_cameras;
  for (const auto& shot_id : shot_ids) {
    const auto& shot = map.GetShot(shot_id);
    const auto& cam = *shot.GetCamera();
    if (added_cameras.find(cam.id) != added_cameras.end()) {
      continue;
    }
    const auto& cam_prior = camera_priors.at(cam.id);
    ba.AddCamera(cam.id, cam, cam_prior, fix_cameras);
    added_cameras.insert(cam.id);
  }

  // collect every landmark observed by the requested shots (deduplicated),
  // then add them as fixed points
  std::unordered_set<map::Landmark*> landmarks;
  for (const auto& shot_id : shot_ids) {
    const auto& shot = map.GetShot(shot_id);
    for (const auto& lm_obs : shot.GetLandmarkObservations()) {
      landmarks.insert(lm_obs.first);
    }
  }
  for (const auto& landmark : landmarks) {
    ba.AddPoint(landmark->id_, landmark->GetGlobalPos(), fix_points);
  }

  // add rig instances shots
  const std::string gps_scale_group = "dummy";  // unused for now
  for (const auto& rig_instance_id : rig_instances_ids) {
    auto& instance = map.GetRigInstance(rig_instance_id);
    std::unordered_map<std::string, std::string> shot_cameras, shot_rig_cameras;

    // we're going to assign GPS constraint to the instance itself
    // by averaging its shot's GPS values (and std dev.)
    Vec3d average_position = Vec3d::Zero();
    double average_std = 0.;
    int gps_count = 0;

    // if any of the instance's shots is flagged as fixed
    // then the entire instance will be fixed
    bool fix_instance = false;

    // First pass over ALL shots of the instance: build the complete
    // shot -> camera / shot -> rig camera maps, accumulate the GPS sums and
    // decide whether the instance is fixed.
    for (const auto& shot_n_rig_camera : instance.GetRigCameras()) {
      const auto& shot_id = shot_n_rig_camera.first;
      auto& shot = map.GetShot(shot_id);
      shot_cameras[shot_id] = shot.GetCamera()->id;
      shot_rig_cameras[shot_id] = shot_n_rig_camera.second->id;

      // NOTE(review): a shot counts as "fixed" when it IS in `shot_ids`.
      // Since the instances iterated here were gathered from `shot_ids`,
      // presumably each of them contains at least one such shot and ends up
      // fixed. This looks inverted with respect to the function's name;
      // confirm the intended semantics before changing it.
      const auto is_fixed = shot_ids.find(shot_id) != shot_ids.end();
      if (is_fixed) {
        fix_instance = true;
        continue;
      }

      if (config["bundle_use_gps"].cast<bool>()) {
        const auto pos = shot.GetShotMeasurements().gps_position_;
        const auto acc = shot.GetShotMeasurements().gps_accuracy_;
        if (pos.HasValue() && acc.HasValue()) {
          average_position += pos.Value();
          average_std += acc.Value();
          ++gps_count;
        }
      }
    }

    // Register the instance exactly once, with the complete maps and the
    // final fixed flag. Doing this inside the loop above would register it
    // once per shot with partially built data.
    ba.AddRigInstance(rig_instance_id, instance.GetPose(), shot_cameras,
                      shot_rig_cameras, fix_instance);

    // only add averaged rig position constraints to moving instances;
    // the sums are divided once, after every shot has been accumulated
    if (!fix_instance && gps_count > 0) {
      average_position /= gps_count;
      average_std /= gps_count;
      ba.AddRigInstancePositionPrior(rig_instance_id, average_position,
                                     Vec3d::Constant(average_std),
                                     gps_scale_group);
    }
  }

  // add observations
  for (const auto& shot_id : shot_ids) {
    const auto& shot = map.GetShot(shot_id);
    for (const auto& lm_obs : shot.GetLandmarkObservations()) {
      const auto& obs = lm_obs.second;
      ba.AddPointProjectionObservation(shot.id_, lm_obs.first->id_, obs.point,
                                       obs.scale);
    }
  }

  ba.SetPointProjectionLossFunction(
      config["loss_function"].cast<std::string>(),
      config["loss_function_threshold"].cast<double>());
  ba.SetInternalParametersPriorSD(
      config["exif_focal_sd"].cast<double>(),
      config["principal_point_sd"].cast<double>(),
      config["radial_distortion_k1_sd"].cast<double>(),
      config["radial_distortion_k2_sd"].cast<double>(),
      config["tangential_distortion_p1_sd"].cast<double>(),
      config["tangential_distortion_p2_sd"].cast<double>(),
      config["radial_distortion_k3_sd"].cast<double>(),
      config["radial_distortion_k4_sd"].cast<double>());
  ba.SetRigParametersPriorSD(config["rig_translation_sd"].cast<double>(),
                             config["rig_rotation_sd"].cast<double>());

  ba.SetNumThreads(config["processes"].cast<int>());
  ba.SetMaxNumIterations(10);
  ba.SetLinearSolverType("DENSE_QR");
  const auto timer_setup = Clock::now();

  {
    // release the Python GIL for the duration of the solve
    py::gil_scoped_release release;
    ba.Run();
  }

  const auto timer_run = Clock::now();

  // write the adjusted rig instance poses back into the map
  for (const auto& rig_instance_id : rig_instances_ids) {
    auto& instance = map.GetRigInstance(rig_instance_id);
    auto i = ba.GetRigInstance(rig_instance_id);
    instance.SetPose(i.GetValue());
  }

  const auto timer_teardown = Clock::now();

  // elapsed time in seconds, at microsecond resolution
  const auto seconds_between = [](const Clock::time_point& from,
                                  const Clock::time_point& to) {
    return std::chrono::duration_cast<std::chrono::microseconds>(to - from)
               .count() /
           1000000.0;
  };

  py::dict wall_times;
  wall_times["setup"] = seconds_between(start, timer_setup);
  wall_times["run"] = seconds_between(timer_setup, timer_run);
  wall_times["teardown"] = seconds_between(timer_run, timer_teardown);

  report["brief_report"] = ba.BriefReport();
  report["wall_times"] = wall_times;
  return report;
}