void AdamOptimizer::update()

in src/runtime/optimizer.cc [256:358]


void AdamOptimizer::update(const Parameter* p)
{
  Context ctx = model->config.lg_ctx;
  Runtime* runtime = model->config.lg_hlr;
  assert(v_values.find(p->region) != v_values.end());
  assert(m_values.find(p->region) != m_values.end());
  assert(p->owner_op != NULL);
  if (p->sync_type == ParameterSyncType::PS) {
    TaskLauncher launcher(ADAM_UPD_PS_TASK_ID,
        TaskArgument(this, sizeof(AdamOptimizer)),
        Predicate::TRUE_PRED, 0/*mapper_id*/,
        FFConfig::get_hash_id(std::string(p->owner_op->name)));
    // regions[0]: region_grad
    launcher.add_region_requirement(
        RegionRequirement(p->region_grad,
                          READ_ONLY, EXCLUSIVE, p->region_grad));
    launcher.add_field(0, FID_DATA);
    // regions[1]: region
    launcher.add_region_requirement(
        RegionRequirement(p->region,
                          READ_WRITE, EXCLUSIVE, p->region));
    launcher.add_field(1, FID_DATA);
    // regions[2]: w_region
    launcher.add_region_requirement(
        RegionRequirement(v_values[p->region].region,
                          READ_WRITE, EXCLUSIVE, v_values[p->region].region));
    launcher.add_field(2, FID_DATA);
    // regions[3]: m_region
    launcher.add_region_requirement(
        RegionRequirement(m_values[p->region].region,
                          READ_WRITE, EXCLUSIVE, m_values[p->region].region));
    launcher.add_field(3, FID_DATA);
    runtime->execute_task(ctx, launcher);
    // Parameter prefetching optimizations to reduce comm. overhead
    // Directly send the parameters back to all worker devices after SGD
    ArgumentMap argmap;
    IndexLauncher index_launcher(PS_PREFETCH_TASK_ID, p->owner_op->task_is,
        TaskArgument(NULL, 0), argmap,
        Predicate::TRUE_PRED, false/*must*/, 0/*mapper_id*/,
        FFConfig::get_hash_id(std::string(p->owner_op->name)));
    // regions[0]: region
    index_launcher.add_region_requirement(
        RegionRequirement(p->part, 0/*projection*/,
                          READ_ONLY, EXCLUSIVE, p->region));
    index_launcher.add_field(0, FID_DATA);
    runtime->execute_index_space(ctx, index_launcher);
  } else if (p->sync_type == ParameterSyncType::NCCL) {
    IndexSpace task_is = p->owner_op->task_is;
    assert(task_is != IndexSpace::NO_SPACE);
    ArgumentMap argmap;
    Domain domain = runtime->get_index_space_domain(ctx, task_is);
    switch (domain.get_dim()) {
#define DIMFUNC(DIM) \
      case DIM: \
      { \
        Rect<DIM> rect = domain; \
        ParallelConfig pc; \
        model->config.find_parallel_config(DIM, p->owner_op->name, pc); \
        int idx = 0; \
        for (PointInRectIterator<DIM> it(rect); it(); it++) { \
          OpMeta* mp = p->owner_op->meta[idx++]; \
          argmap.set_point(*it, TaskArgument(&mp, sizeof(OpMeta*))); \
        } \
        break; \
      }
      LEGION_FOREACH_N(DIMFUNC)
#undef DIMFUNC
      default:
        assert(false);
    }
    IndexLauncher launcher(ADAM_UPD_NCCL_TASK_ID, task_is,
        TaskArgument(this, sizeof(AdamOptimizer)), argmap,
        Predicate::TRUE_PRED, false/*must_epoch*/, 0/*mapper_id*/,
        FFConfig::get_hash_id(p->owner_op->name));
    // regions[0]: region_grad
    launcher.add_region_requirement(
        RegionRequirement(p->part_grad, 0/*projection id*/,
                          READ_ONLY, EXCLUSIVE, p->region_grad));
    launcher.add_field(0, FID_DATA);
    // regions[1]: region
    launcher.add_region_requirement(
        RegionRequirement(p->part, 0/*projection id*/,
                          READ_WRITE, EXCLUSIVE, p->region));
    launcher.add_field(1, FID_DATA);
    // regions[2]: w_region
    launcher.add_region_requirement(
        RegionRequirement(v_values[p->region].part, 0/*projection id*/,
                          READ_WRITE, EXCLUSIVE, v_values[p->region].region));
    launcher.add_field(2, FID_DATA);
    // regions[3]: m_region
    launcher.add_region_requirement(
        RegionRequirement(m_values[p->region].part, 0/*projection id*/,
                          READ_WRITE, EXCLUSIVE, m_values[p->region].region));
    launcher.add_field(3, FID_DATA);
    //MustEpochLauncher must_epoch_launcher;
    //must_epoch_launcher.add_index_task(launcher);
    FutureMap fm = runtime->execute_index_space(ctx, launcher);
    //runtime->execute_must_epoch(ctx, must_epoch_launcher);
    runtime->issue_execution_fence(ctx);
  } else {
    assert(false);
  }
}