Status NeuronBatchSharder::Setup()

in runtime/direct/dynamic_batch.cc [36:133]


Status NeuronBatchSharder::Setup(const NeuronExecutableInfo& info,
                                 const std::vector<Tensor>& inputs) {
  // Validate the batch axis argument for every input and output
  for (int idx = 0; idx < info.input_batch_axis.i_size(); ++idx) {
    int batch_axis = info.input_batch_axis.i(idx);
    if (TF_PREDICT_FALSE(!IsValidBatchAxis(batch_axis))) {
      return errors::InvalidArgument("Input #", idx, " has invalid batch axis ",
                                     batch_axis);
    }
  }
  for (int idx = 0; idx < info.output_batch_axis.i_size(); ++idx) {
    int batch_axis = info.output_batch_axis.i(idx);
    if (TF_PREDICT_FALSE(!IsValidBatchAxis(batch_axis))) {
      return errors::InvalidArgument("Output #", idx,
                                     " has invalid batch axis ", batch_axis);
    }
  }

  // Initialize client output shapes to NEFF output shapes
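  // (shapes at dynamic batch axes are overwritten with the client batch size below)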
  int num_inputs = info.input_shapes.shape_size();
  int num_outputs = info.output_shapes.shape_size();
  client_output_shapes_.reserve(num_outputs);
  for (const TensorShapeProto& neff_shape_proto : info.output_shapes.shape()) {
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(neff_shape_proto));
    client_output_shapes_.emplace_back(neff_shape_proto);
  }

  // Initialize input/output need-sharding markers to false
  inputs_need_sharding_.resize(num_inputs, false);
  outputs_need_sharding_.resize(num_outputs, false);

  // Read client/NEFF batch size candidates from inputs/input_shapes
  std::unordered_set<int> client_batch_size_set;
  std::unordered_set<int> neff_batch_size_set;
  for (int idx = 0; idx < num_inputs; ++idx) {
    int batch_axis = info.input_batch_axis.i(idx);
    if (!BatchAxisIsDynamic(batch_axis)) {
      continue;
    }

    // Client/NEFF batch sizes
    const Tensor& input_tensor = inputs.at(idx);
    const TensorShapeProto& neff_shape_proto = info.input_shapes.shape(idx);
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(neff_shape_proto));
    TensorShape neff_shape(neff_shape_proto);
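    // If the client shape already matches the NEFF shape, this input needs no sharding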
    if (input_tensor.shape() == neff_shape) {
      continue;
    }
    if (TF_PREDICT_FALSE(input_tensor.dims() <= batch_axis + 1)) {
      return errors::InvalidArgument(
          "Input tensor #", idx, " has dynamic batch axis ", batch_axis,
          ", but it only has ", input_tensor.dims(), " dimensions");
    }
    if (TF_PREDICT_FALSE(neff_shape.dims() <= batch_axis + 1)) {
      return errors::InvalidArgument(
          "NEFF input tensor #", idx, " has dynamic batch axis ", batch_axis,
          ", but it only has ", neff_shape.dims(), " dimensions");
    }
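    // Record the observed batch sizes; all dynamic inputs must agree on a
    // single client size and a single NEFF size, as checked below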
    client_batch_size_set.insert(input_tensor.dim_size(batch_axis));
    neff_batch_size_set.insert(neff_shape.dim_size(batch_axis));
    inputs_need_sharding_.at(idx) = true;
  }

  // If client_batch_size_set is empty, no input needs sharding: either nothing
  // has a dynamic batch axis or every client shape already matches the NEFF shape
  can_skip_ = client_batch_size_set.empty();
  if (can_skip_) {
    VLOG(1) << "NeuronBatchSharder::Setup done without any dynamic batch size";
    return Status::OK();
  }
  if (TF_PREDICT_FALSE(client_batch_size_set.size() > 1)) {
    return errors::InvalidArgument("Inconsistent client batch sizes");
  }
  if (TF_PREDICT_FALSE(neff_batch_size_set.size() != 1)) {
    return errors::InvalidArgument("Inconsistent NEFF batch sizes");
  }

  // Each set now holds exactly one element; use it as the client/NEFF batch size
  client_batch_size_ = *client_batch_size_set.begin();
  if (TF_PREDICT_FALSE(client_batch_size_ < 0)) {
    return errors::InvalidArgument("Invalid client batch size ",
                                   client_batch_size_);
  }
  neff_batch_size_ = *neff_batch_size_set.begin();
  if (TF_PREDICT_FALSE(neff_batch_size_ < 0)) {
    return errors::InvalidArgument("Invalid NEFF batch size ",
                                   neff_batch_size_);
  }
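  // Mark dynamic outputs for sharding and patch their client-visible batch size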
  for (int idx = 0; idx < num_outputs; ++idx) {
    int batch_axis = info.output_batch_axis.i(idx);
    bool need_sharding = BatchAxisIsDynamic(batch_axis);
    outputs_need_sharding_.at(idx) = need_sharding;
    if (TF_PREDICT_TRUE(need_sharding)) {
      client_output_shapes_.at(idx).set_dim(batch_axis, client_batch_size_);
    }
  }
  VLOG(1) << "NeuronBatchSharder::Setup done after finding dynamic batch size";
  return Status::OK();
}
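
The two batch sizes recorded here drive the actual sharding done elsewhere in the class. As a rough illustration of the arithmetic they imply, here is a minimal, hypothetical helper (not part of dynamic_batch.cc); the name and signature are assumptions made purely for this sketch:

// Hypothetical illustration only: once Setup() has succeeded, a client batch of
// client_batch_size examples is presumably processed as ceil(client / neff)
// NEFF-sized shards, the last of which may be partially filled.
static int NumShards(int client_batch_size, int neff_batch_size) {
  // Both sizes are positive here, as enforced by the checks in Setup()
  return (client_batch_size + neff_batch_size - 1) / neff_batch_size;
}

For example, client_batch_size_ = 10 with neff_batch_size_ = 4 would imply three shards covering 4, 4, and 2 examples.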