in runtime/direct/dynamic_batch.cc [36:133]
Status NeuronBatchSharder::Setup(const NeuronExecutableInfo& info,
                                 const std::vector<Tensor>& inputs) {
  // Validates dynamic-batch metadata against the concrete input tensors and
  // derives the sharding plan.  On success this populates:
  //   - can_skip_: true iff no dynamic-batch input differs from its NEFF
  //     shape (nothing needs sharding);
  //   - client_batch_size_ / neff_batch_size_: the single batch size observed
  //     on the client side / baked into the NEFF, respectively;
  //   - inputs_need_sharding_ / outputs_need_sharding_: per-tensor markers;
  //   - client_output_shapes_: NEFF output shapes with each dynamic batch
  //     dimension rewritten to the client batch size.
  // Returns InvalidArgument if the metadata is internally inconsistent or
  // does not match the provided input tensors.
  //
  // Batch axis argument validity checking
  for (int idx = 0; idx < info.input_batch_axis.i_size(); ++idx) {
    int batch_axis = info.input_batch_axis.i(idx);
    if (TF_PREDICT_FALSE(!IsValidBatchAxis(batch_axis))) {
      return errors::InvalidArgument("Input #", idx, " has invalid batch axis ",
                                     batch_axis);
    }
  }
  for (int idx = 0; idx < info.output_batch_axis.i_size(); ++idx) {
    int batch_axis = info.output_batch_axis.i(idx);
    if (TF_PREDICT_FALSE(!IsValidBatchAxis(batch_axis))) {
      return errors::InvalidArgument("Output #", idx,
                                     " has invalid batch axis ", batch_axis);
    }
  }
  int num_inputs = info.input_shapes.shape_size();
  int num_outputs = info.output_shapes.shape_size();
  // The batch-axis lists and the runtime input vector are indexed positionally
  // against the shape lists below; verify the sizes up front so a malformed
  // executable cannot trigger out-of-bounds repeated-field/vector access.
  if (TF_PREDICT_FALSE(info.input_batch_axis.i_size() != num_inputs)) {
    return errors::InvalidArgument(
        "Number of input batch axes ", info.input_batch_axis.i_size(),
        " does not match number of input shapes ", num_inputs);
  }
  if (TF_PREDICT_FALSE(info.output_batch_axis.i_size() != num_outputs)) {
    return errors::InvalidArgument(
        "Number of output batch axes ", info.output_batch_axis.i_size(),
        " does not match number of output shapes ", num_outputs);
  }
  if (TF_PREDICT_FALSE(static_cast<int>(inputs.size()) != num_inputs)) {
    return errors::InvalidArgument("Number of input tensors ", inputs.size(),
                                   " does not match number of input shapes ",
                                   num_inputs);
  }
  // Initialize client output shapes to NEFF output shapes
  client_output_shapes_.reserve(num_outputs);
  for (const TensorShapeProto& neff_shape_proto : info.output_shapes.shape()) {
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(neff_shape_proto));
    client_output_shapes_.emplace_back(neff_shape_proto);
  }
  // Initialize input/output need-sharding markers to false
  inputs_need_sharding_.resize(num_inputs, false);
  outputs_need_sharding_.resize(num_outputs, false);
  // Read client/NEFF batch size candidates from inputs/input_shapes
  std::unordered_set<int> client_batch_size_set;
  std::unordered_set<int> neff_batch_size_set;
  for (int idx = 0; idx < num_inputs; ++idx) {
    int batch_axis = info.input_batch_axis.i(idx);
    if (!BatchAxisIsDynamic(batch_axis)) {
      continue;
    }
    // Client/NEFF batch sizes
    const Tensor& input_tensor = inputs.at(idx);
    const TensorShapeProto& neff_shape_proto = info.input_shapes.shape(idx);
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(neff_shape_proto));
    TensorShape neff_shape(neff_shape_proto);
    if (input_tensor.shape() == neff_shape) {
      // Client shape already matches the compiled shape; no sharding needed
      // for this input, and it contributes no batch-size candidate.
      continue;
    }
    // dim_size(batch_axis) below requires batch_axis < dims(); a 1-D tensor
    // with batch_axis 0 is valid, so reject only when batch_axis is out of
    // range (note: NOT `<= batch_axis + 1`, which would wrongly reject
    // tensors whose batch axis is the last dimension).
    if (TF_PREDICT_FALSE(input_tensor.dims() <= batch_axis)) {
      return errors::InvalidArgument(
          "Input tensor #", idx, " has dynamic batch axis ", batch_axis,
          ", but it only has ", input_tensor.dims(), " dimensions");
    }
    if (TF_PREDICT_FALSE(neff_shape.dims() <= batch_axis)) {
      return errors::InvalidArgument(
          "NEFF input tensor #", idx, " has dynamic batch axis ", batch_axis,
          ", but it only has ", neff_shape.dims(), " dimensions");
    }
    client_batch_size_set.insert(input_tensor.dim_size(batch_axis));
    neff_batch_size_set.insert(neff_shape.dim_size(batch_axis));
    inputs_need_sharding_.at(idx) = true;
  }
  // client_batch_size_set is empty; everything is supposed to have fixed shape
  can_skip_ = client_batch_size_set.empty();
  if (can_skip_) {
    VLOG(1) << "NeuronBatchSharder::Setup done without any dynamic batch size";
    return Status::OK();
  }
  // All dynamic inputs must agree on a single batch size on each side.
  if (TF_PREDICT_FALSE(client_batch_size_set.size() > 1)) {
    return errors::InvalidArgument("Inconsistent client batch sizes");
  }
  if (TF_PREDICT_FALSE(neff_batch_size_set.size() != 1)) {
    return errors::InvalidArgument("Inconsistent NEFF batch sizes");
  }
  // Set has only one element; use it as the client/NEFF batch size
  client_batch_size_ = *client_batch_size_set.begin();
  if (TF_PREDICT_FALSE(client_batch_size_ < 0)) {
    return errors::InvalidArgument("Invalid client batch size ",
                                   client_batch_size_);
  }
  neff_batch_size_ = *neff_batch_size_set.begin();
  if (TF_PREDICT_FALSE(neff_batch_size_ < 0)) {
    return errors::InvalidArgument("Invalid NEFF batch size ",
                                   neff_batch_size_);
  }
  // Rewrite dynamic output batch dimensions to the client batch size.
  for (int idx = 0; idx < num_outputs; ++idx) {
    int batch_axis = info.output_batch_axis.i(idx);
    bool need_sharding = BatchAxisIsDynamic(batch_axis);
    outputs_need_sharding_.at(idx) = need_sharding;
    if (TF_PREDICT_TRUE(need_sharding)) {
      // Guard the set_dim index the same way as the inputs above.
      if (TF_PREDICT_FALSE(client_output_shapes_.at(idx).dims() <=
                           batch_axis)) {
        return errors::InvalidArgument(
            "Output tensor #", idx, " has dynamic batch axis ", batch_axis,
            ", but it only has ", client_output_shapes_.at(idx).dims(),
            " dimensions");
      }
      client_output_shapes_.at(idx).set_dim(batch_axis, client_batch_size_);
    }
  }
  VLOG(1) << "NeuronBatchSharder::Setup done after finding dynamic batch size";
  return Status::OK();
}