in src/runtime/optimizer.cc [256:358]
// Launch the Adam weight-update task(s) for parameter `p`, dispatching on the
// parameter's synchronization strategy:
//   - PS (parameter server): one single-task update over the whole parameter
//     region, followed by an index-space "prefetch" launch that pushes the
//     freshly updated weights back out to every worker shard.
//   - NCCL: one index-space update launch across the operator's task index
//     space (one point task per device shard); per-point OpMeta* handles are
//     passed so each shard can reach its device-local state (presumably the
//     NCCL communicator lives there — verify against the task implementation).
// Any other sync type is a programming error (assert(false)).
void AdamOptimizer::update(const Parameter* p)
{
Context ctx = model->config.lg_ctx;
Runtime* runtime = model->config.lg_hlr;
// v_values/m_values map a parameter's logical region to the regions holding
// Adam's moment-estimate buffers for that parameter; both must have been
// created (e.g. at optimizer init) before update() is called.
assert(v_values.find(p->region) != v_values.end());
assert(m_values.find(p->region) != m_values.end());
assert(p->owner_op != NULL);
if (p->sync_type == ParameterSyncType::PS) {
// NOTE: the optimizer object itself is passed as the task argument by a raw
// byte copy, so the task sees a snapshot of the hyperparameters (alpha,
// beta1, beta2, ...) and must not dereference any pointer members.
// The mapper tag derived from the owner op's name keeps the update on the
// same mapping as the op itself.
TaskLauncher launcher(ADAM_UPD_PS_TASK_ID,
TaskArgument(this, sizeof(AdamOptimizer)),
Predicate::TRUE_PRED, 0/*mapper_id*/,
FFConfig::get_hash_id(std::string(p->owner_op->name)));
// Region requirement order is the task's ABI: the N in add_field(N, ...)
// must match the position of the add_region_requirement call it follows.
// regions[0]: gradient (read-only input)
launcher.add_region_requirement(
RegionRequirement(p->region_grad,
READ_ONLY, EXCLUSIVE, p->region_grad));
launcher.add_field(0, FID_DATA);
// regions[1]: weights (updated in place)
launcher.add_region_requirement(
RegionRequirement(p->region,
READ_WRITE, EXCLUSIVE, p->region));
launcher.add_field(1, FID_DATA);
// regions[2]: v moment buffer (updated in place)
launcher.add_region_requirement(
RegionRequirement(v_values[p->region].region,
READ_WRITE, EXCLUSIVE, v_values[p->region].region));
launcher.add_field(2, FID_DATA);
// regions[3]: m moment buffer (updated in place)
launcher.add_region_requirement(
RegionRequirement(m_values[p->region].region,
READ_WRITE, EXCLUSIVE, m_values[p->region].region));
launcher.add_field(3, FID_DATA);
runtime->execute_task(ctx, launcher);
// Parameter prefetching optimization to reduce comm. overhead:
// eagerly broadcast the updated weights to all worker devices right after
// the Adam update, instead of waiting for the next forward pass to pull
// them. The READ_ONLY index launch over p->part triggers the copies.
ArgumentMap argmap;
IndexLauncher index_launcher(PS_PREFETCH_TASK_ID, p->owner_op->task_is,
TaskArgument(NULL, 0), argmap,
Predicate::TRUE_PRED, false/*must*/, 0/*mapper_id*/,
FFConfig::get_hash_id(std::string(p->owner_op->name)));
// regions[0]: region
index_launcher.add_region_requirement(
RegionRequirement(p->part, 0/*projection*/,
READ_ONLY, EXCLUSIVE, p->region));
index_launcher.add_field(0, FID_DATA);
runtime->execute_index_space(ctx, index_launcher);
} else if (p->sync_type == ParameterSyncType::NCCL) {
IndexSpace task_is = p->owner_op->task_is;
assert(task_is != IndexSpace::NO_SPACE);
ArgumentMap argmap;
Domain domain = runtime->get_index_space_domain(ctx, task_is);
// Populate argmap with a per-point OpMeta* (device-local state for that
// shard). The DIMFUNC macro instantiates the same loop for every supported
// domain dimensionality; point tasks are visited in PointInRectIterator
// order, which must match the ordering used when owner_op->meta was filled.
switch (domain.get_dim()) {
#define DIMFUNC(DIM) \
case DIM: \
{ \
Rect<DIM> rect = domain; \
ParallelConfig pc; \
model->config.find_parallel_config(DIM, p->owner_op->name, pc); \
int idx = 0; \
for (PointInRectIterator<DIM> it(rect); it(); it++) { \
OpMeta* mp = p->owner_op->meta[idx++]; \
argmap.set_point(*it, TaskArgument(&mp, sizeof(OpMeta*))); \
} \
break; \
}
LEGION_FOREACH_N(DIMFUNC)
#undef DIMFUNC
default:
assert(false);
}
// Same byte-copied optimizer snapshot and region/field layout as the PS
// branch, but launched over the op's index space so each shard updates its
// own partition (p->part / p->part_grad) in parallel.
IndexLauncher launcher(ADAM_UPD_NCCL_TASK_ID, task_is,
TaskArgument(this, sizeof(AdamOptimizer)), argmap,
Predicate::TRUE_PRED, false/*must_epoch*/, 0/*mapper_id*/,
FFConfig::get_hash_id(p->owner_op->name));
// regions[0]: gradient partition (read-only input)
launcher.add_region_requirement(
RegionRequirement(p->part_grad, 0/*projection id*/,
READ_ONLY, EXCLUSIVE, p->region_grad));
launcher.add_field(0, FID_DATA);
// regions[1]: weight partition (updated in place)
launcher.add_region_requirement(
RegionRequirement(p->part, 0/*projection id*/,
READ_WRITE, EXCLUSIVE, p->region));
launcher.add_field(1, FID_DATA);
// regions[2]: v moment partition (updated in place)
launcher.add_region_requirement(
RegionRequirement(v_values[p->region].part, 0/*projection id*/,
READ_WRITE, EXCLUSIVE, v_values[p->region].region));
launcher.add_field(2, FID_DATA);
// regions[3]: m moment partition (updated in place)
launcher.add_region_requirement(
RegionRequirement(m_values[p->region].part, 0/*projection id*/,
READ_WRITE, EXCLUSIVE, m_values[p->region].region));
launcher.add_field(3, FID_DATA);
//MustEpochLauncher must_epoch_launcher;
//must_epoch_launcher.add_index_task(launcher);
FutureMap fm = runtime->execute_index_space(ctx, launcher);
//runtime->execute_must_epoch(ctx, must_epoch_launcher);
// Fence so downstream work observes the completed NCCL update
// (NOTE(review): likely ordering the collective against later launches —
// confirm against the task implementation before removing).
runtime->issue_execution_fence(ctx);
} else {
// Unsupported ParameterSyncType: fail loudly rather than silently skip
// the update.
assert(false);
}
}