Parameter FFModel::create_linear_weight()

in src/runtime/model.cc [1017:1142]


Parameter FFModel::create_linear_weight(Op* op,
                                        const int dims[],
                                        DataType data_type,
                                        Initializer* initializer,
                                        bool create_grad,
                                        ParameterSyncType comm_type)
{
  // Creates an NDIM-dimensional weight owned by `op` and partitions it over
  // the op's TDIM-dimensional task index space.
  //
  //   dims        - weight extents, outermost first: dims[0] is mapped to
  //                 Legion dimension NDIM-1 (weight.adim is stored reversed).
  //                 dims[0] is split into equal slices along task dimension 0,
  //                 so it must be divisible by the task extent on that axis.
  //   data_type   - DT_FLOAT, DT_DOUBLE or DT_INT32 (anything else asserts).
  //   initializer - must be non-NULL; invoked after the forward region and
  //                 partition have been created.
  //   create_grad - a gradient region is created only when this is true AND
  //                 config.computationMode == COMP_MODE_TRAINING.
  //   comm_type   - PS:   a single copy of the weight; task points differing
  //                       only in task dimensions 1..TDIM-1 alias the same
  //                       slice (their transform entries are zero).
  //                 NCCL: one replica of the weight per combination of task
  //                       dimensions 1..TDIM-1; every task point gets a
  //                       disjoint slice.
  // The gradient region always uses the replicated (NCCL-style) layout.
  std::string pcname = op->name;
  IndexSpaceT<TDIM> part_is = (IndexSpaceT<TDIM>)get_or_create_task_is(TDIM, pcname);
  Context ctx = config.lg_ctx;
  Runtime* runtime = config.lg_hlr;
  Rect<TDIM> part_rect = runtime->get_index_space_domain(ctx, part_is);
  int num_parts[TDIM];
  for (int i = 0; i < TDIM; i++)
    num_parts[i] = part_rect.hi[i] - part_rect.lo[i] + 1;
  // Every layout below splits dims[0] into num_parts[0] equal slices; check
  // this once here rather than relying on the later completeness asserts.
  assert(dims[0] % num_parts[0] == 0);
  const int slice_size = dims[0] / num_parts[0];
  // Number of weight copies in the replicated layout: the product of the
  // task-space extents along task dimensions 1..TDIM-1.
  int num_replicas = 1;
  for (int i = 1; i < TDIM; i++)
    num_replicas *= num_parts[i];

  Parameter weight;
  weight.sync_type = comm_type;
  weight.owner_op = op;
  weight.numDim = NDIM;
  weight.data_type = data_type;
  for (int i = 0; i < NDIM; i++)
    weight.adim[i] = dims[NDIM-1-i];
  FieldSpace fs = runtime->create_field_space(ctx);
  FieldAllocator allocator= runtime->create_field_allocator(ctx, fs);
  switch (data_type) {
    case DT_FLOAT:
      allocator.allocate_field(sizeof(float), FID_DATA);
      break;
    case DT_DOUBLE:
      allocator.allocate_field(sizeof(double), FID_DATA);
      break;
    case DT_INT32:
      allocator.allocate_field(sizeof(int), FID_DATA);
      break;
    default:
      assert(false);
  }

  // Creates an index space shaped like the weight, except that Legion
  // dimension NDIM-1 (the dims[0] axis) has `outer_extent` points.
  auto create_weight_is = [&](int outer_extent) -> IndexSpaceT<NDIM> {
    Point<NDIM> hi;
    for (int i = 0; i < NDIM; i++)
      hi[i] = dims[NDIM-1-i]-1;
    hi[NDIM-1] = outer_extent - 1;
    Rect<NDIM> rect(Point<NDIM>::ZEROES(), hi);
    return runtime->create_index_space(ctx, rect);
  };

  // Partitions `is` by restriction over part_is. Task point p receives a
  // block that is `slice_size` wide along Legion dimension NDIM-1, starting
  // at p[0] * slice_size. When `replicated`, p[1..TDIM-1] additionally select
  // one of the dims[0]-sized replicas, which makes the partition disjoint.
  auto partition_weight_is = [&](IndexSpaceT<NDIM> is,
                                 bool replicated) -> IndexPartition {
    Point<NDIM> hi;
    for (int i = 0; i < NDIM; i++)
      hi[i] = dims[NDIM-1-i]-1;
    hi[NDIM-1] = slice_size - 1;
    Rect<NDIM> extent(Point<NDIM>::ZEROES(), hi);
    Transform<NDIM, TDIM> transform;
    for (int i = 0; i < NDIM; i++)
      for (int j = 0; j < TDIM; j++)
        transform[i][j] = 0;
    transform[NDIM-1][0] = slice_size;
    if (replicated) {
      for (int i = 1; i < TDIM; i++)
        transform[NDIM-1][i] = transform[NDIM-1][i-1] * num_parts[i-1];
    }
    IndexPartition ip = runtime->create_partition_by_restriction(
        ctx, is, part_is, transform, extent);
    assert(runtime->is_index_partition_complete(ctx, ip));
    if (replicated) {
      assert(runtime->is_index_partition_disjoint(ctx, ip));
    }
    return ip;
  };

  // Step 1: forward region and partition
  if (weight.sync_type == ParameterSyncType::PS) {
    IndexSpaceT<NDIM> is = create_weight_is(dims[0]);
    weight.region = runtime->create_logical_region(ctx, is, fs);
    IndexPartition ip = partition_weight_is(is, false /*replicated*/);
    weight.part = runtime->get_logical_partition(
        ctx, weight.region, ip);
  } else if (weight.sync_type == ParameterSyncType::NCCL) {
    // FIXME: Currently only support the sample dimension for operators with NCCL
    //for (int i = 0; i < TDIM-1; i++)
    //  assert(num_parts[i] == 1);
    IndexSpaceT<NDIM> is = create_weight_is(num_replicas * dims[0]);
    weight.region = runtime->create_logical_region(ctx, is, fs);
    IndexPartition ip = partition_weight_is(is, true /*replicated*/);
    weight.part = runtime->get_logical_partition(
        ctx, weight.region, ip);
  } else {
    assert(false);
  }
  // Step 2: initialize region
  if (initializer == NULL) {
    assert(false); // add weight initializer should be set before
  } else {
    initializer->init(this, &weight);
  }
  // Step 3: backward region (always in the replicated layout)
  if (create_grad && config.computationMode == COMP_MODE_TRAINING) {
    IndexSpaceT<NDIM> is = create_weight_is(num_replicas * dims[0]);
    weight.region_grad = runtime->create_logical_region(ctx, is, fs);
    IndexPartition ip = partition_weight_is(is, true /*replicated*/);
    weight.part_grad = runtime->get_logical_partition(
        ctx, weight.region_grad, ip);
  }
  return weight;
}