in src/runtime/model.cc [1145:1265]
// Builds the weight Parameter owned by `op`: allocates a field space with a
// single FID_DATA field, creates the forward region and partition according
// to `comm_type`, runs `initializer` on the result and, when requested,
// creates the gradient region and partition.
//
//   op          - owner of the weight; op->name selects the 4-D task index
//                 space that is used as the partition color space.
//   dims        - NDIM extents. They are stored reversed
//                 (adim[i] = dims[NDIM-1-i]), so dims[0] is the extent of the
//                 last region dimension.
//   data_type   - DT_FLOAT, DT_DOUBLE or DT_INT32; anything else asserts.
//   initializer - must not be NULL (asserted); initializer->init() is invoked
//                 on the weight once the forward region exists.
//   create_grad - the gradient region is created only when this is true and
//                 config.computationMode == COMP_MODE_TRAINING.
//   comm_type   - ParameterSyncType::PS or ParameterSyncType::NCCL; anything
//                 else asserts.
//
// NOTE(review): NDIM is not declared inside this block; presumably it is a
// template parameter declared immediately above -- confirm.
Parameter FFModel::create_conv_weight(Op* op,
                                      const int dims[],
                                      DataType data_type,
                                      Initializer* initializer,
                                      bool create_grad,
                                      ParameterSyncType comm_type)
{
  Context ctx = config.lg_ctx;
  Runtime* runtime = config.lg_hlr;

  // Number of partitions along each axis of the operator's task index space
  // (axis 0 = w, 1 = h, 2 = c, 3 = n).
  std::string task_name = op->name;
  IndexSpaceT<4> task_is =
      static_cast<IndexSpaceT<4> >(get_or_create_task_is(4, task_name));
  Rect<4> task_rect = runtime->get_index_space_domain(ctx, task_is);
  int num_par_w = task_rect.hi[0] - task_rect.lo[0] + 1;
  int num_par_h = task_rect.hi[1] - task_rect.lo[1] + 1;
  int num_par_c = task_rect.hi[2] - task_rect.lo[2] + 1;
  int num_par_n = task_rect.hi[3] - task_rect.lo[3] + 1;
  // Splitting over the channel dimension is not supported yet.
  assert(num_par_c == 1);

  // Parameter metadata, plus the upper corner of a single copy of the weight
  // (both use the reversed dimension order).
  Parameter weight;
  weight.owner_op = op;
  weight.sync_type = comm_type;
  weight.data_type = data_type;
  weight.numDim = NDIM;
  Point<NDIM> shape_hi;
  for (int d = NDIM - 1; d >= 0; d--) {
    weight.adim[NDIM-1-d] = dims[d];
    shape_hi[NDIM-1-d] = dims[d] - 1;
  }

  // One field (FID_DATA) whose size matches the element type.
  FieldSpace fs = runtime->create_field_space(ctx);
  FieldAllocator allocator = runtime->create_field_allocator(ctx, fs);
  size_t elem_bytes = 0;
  switch (data_type) {
    case DT_FLOAT:
      elem_bytes = sizeof(float);
      break;
    case DT_DOUBLE:
      elem_bytes = sizeof(double);
      break;
    case DT_INT32:
      elem_bytes = sizeof(int);
      break;
    default:
      assert(false);
  }
  // elem_bytes is still 0 only when the assert above is compiled out; no
  // field is allocated in that case.
  if (elem_bytes != 0) {
    allocator.allocate_field(elem_bytes, FID_DATA);
  }

  // Creates an index space whose last dimension is
  // num_par_n * num_par_h * num_par_w times longer than a single copy of the
  // weight, i.e. the copies are stacked along the last dimension.
  auto create_stacked_space = [&]() -> IndexSpaceT<NDIM> {
    Point<NDIM> stacked_hi = shape_hi;
    stacked_hi[NDIM-1] = num_par_n * num_par_h * num_par_w * dims[0] - 1;
    Rect<NDIM> stacked(Point<NDIM>::ZEROES(), stacked_hi);
    return runtime->create_index_space(ctx, stacked);
  };

  // Partitions `parent` by restriction over the task index space, using a
  // single copy of the weight as the extent. With `stacked` false the
  // transform is all zeros (every color receives the same extent); with
  // `stacked` true the last row holds the strides that step through the
  // stacked copies in (w, h, c, n) order.
  auto restrict_partition = [&](IndexSpaceT<NDIM> parent,
                                bool stacked) -> IndexPartition {
    Transform<NDIM, 4> tf;
    for (int row = 0; row < NDIM; row++) {
      for (int col = 0; col < 4; col++) {
        tf[row][col] = 0;
      }
    }
    if (stacked) {
      // Running product: dims[0], dims[0]*w, dims[0]*w*h, dims[0]*w*h*c.
      const int fanout[4] = {1, num_par_w, num_par_h, num_par_c};
      int stride = dims[0];
      for (int col = 0; col < 4; col++) {
        stride *= fanout[col];
        tf[NDIM-1][col] = stride;
      }
    }
    Rect<NDIM> extent(Point<NDIM>::ZEROES(), shape_hi);
    return runtime->create_partition_by_restriction(
        ctx, parent, task_is, tf, extent);
  };

  // Step 1: forward region and partition
  if (weight.sync_type == ParameterSyncType::PS) {
    // A single copy of the weight; the partition is complete but its
    // disjointness is not checked.
    Rect<NDIM> whole(Point<NDIM>::ZEROES(), shape_hi);
    IndexSpaceT<NDIM> is = runtime->create_index_space(ctx, whole);
    weight.region = runtime->create_logical_region(ctx, is, fs);
    IndexPartition ip = restrict_partition(is, false);
    assert(runtime->is_index_partition_complete(ctx, ip));
    weight.part = runtime->get_logical_partition(ctx, weight.region, ip);
  } else if (weight.sync_type == ParameterSyncType::NCCL) {
    // Only sample and attribute parallelism are supported with NCCL.
    assert(num_par_c == 1);
    IndexSpaceT<NDIM> is = create_stacked_space();
    weight.region = runtime->create_logical_region(ctx, is, fs);
    IndexPartition ip = restrict_partition(is, true);
    assert(runtime->is_index_partition_complete(ctx, ip));
    assert(runtime->is_index_partition_disjoint(ctx, ip));
    weight.part = runtime->get_logical_partition(ctx, weight.region, ip);
  } else {
    // Unsupported parameter synchronization type
    assert(false);
  }

  // Step 2: initialize the forward region; an initializer is mandatory.
  if (initializer != NULL) {
    initializer->init(this, &weight);
  } else {
    assert(false);
  }

  // Step 3: backward region and partition. The gradient always uses the
  // stacked layout, whatever the synchronization type is.
  if (!(create_grad && config.computationMode == COMP_MODE_TRAINING)) {
    return weight;
  }
  IndexSpaceT<NDIM> is = create_stacked_space();
  weight.region_grad = runtime->create_logical_region(ctx, is, fs);
  IndexPartition ip = restrict_partition(is, true);
  assert(runtime->is_index_partition_complete(ctx, ip));
  assert(runtime->is_index_partition_disjoint(ctx, ip));
  weight.part_grad = runtime->get_logical_partition(
      ctx, weight.region_grad, ip);
  return weight;
}