in tensorflow_networking/mpi_collectives/kernels/mpi_ops.cc [294:421]
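// Given all of the MPIRequests that the ranks have queued under a tensor
// name, check that every rank submitted a compatible request (same data
// type, same operation, and matching shapes where required) and construct
// the single MPIResponse that answers them. On any mismatch the response
// type is ERROR and a human-readable message is attached; in every case the
// queued requests for this name are removed from the message table.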
MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
                                 std::string name) {
  bool error = false;
  auto it = message_table->find(name);
  assert(it != message_table->end());
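
  // Each entry in `requests` is the MPIRequest that one rank queued for this
  // tensor name; the checks below verify that all ranks agree on data type,
  // operation, and (where required) shape.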
  std::vector<MPIRequest> requests = it->second;
  assert(requests.size() > 0);

  std::ostringstream error_message_stream;

  // Check that all data types being reduced or gathered are identical.
  auto data_type = requests[0].tensor_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    auto request_type = requests[i].tensor_type();
    if (data_type != request_type) {
      error = true;
      error_message_stream << "Mismatched data types: One rank had type "
                           << DataType_Name(data_type)
                           << ", but another rank had type "
                           << DataType_Name(request_type) << ".";
      break;
    }
  }

  // Check that all requested operations are the same.
  auto message_type = requests[0].request_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    if (error) {
      break;
    }

    auto request_type = requests[i].request_type();
    if (message_type != request_type) {
      error = true;
      error_message_stream << "Mismatched MPI operations: One rank did an "
                           << message_type << ", but another rank did an "
                           << request_type << ".";
      break;
    }
  }

  // If we are doing an allreduce, check that all tensor shapes are identical.
  if (message_type == MPIRequest::ALLREDUCE) {
    TensorShape tensor_shape = requests[0].tensor_shape();
    for (unsigned int i = 1; i < requests.size(); i++) {
      if (error) {
        break;
      }

      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape != request_shape) {
        error = true;
        error_message_stream << "Mismatched allreduce tensor shapes: "
                             << "One rank reduced a tensor of shape "
                             << tensor_shape.DebugString()
                             << ", but another rank sent a tensor of shape "
                             << request_shape.DebugString() << ".";
        break;
      }
    }
  }

  // If we are doing an allgather, make sure all but the first dimension are
  // the same. The first dimension may differ across ranks, and the output
  // tensor's first dimension is the sum of those sizes, which are collected
  // by rank.
  if (message_type == MPIRequest::ALLGATHER) {
    TensorShape tensor_shape = requests[0].tensor_shape();

    if (tensor_shape.dims() == 0) {
      error = true;
      error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
    }

    for (unsigned int i = 1; i < requests.size(); i++) {
      if (error) {
        break;
      }

      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape.dims() != request_shape.dims()) {
        error = true;
        error_message_stream << "Mismatched allgather tensor shapes: "
                             << "One rank gathered a tensor of rank "
                             << tensor_shape.dims()
                             << ", but another rank sent a tensor of rank "
                             << request_shape.dims() << ".";
        break;
      }
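
      // Dimension 0 is allowed to differ between ranks, since allgather
      // concatenates along it, so only dimensions 1..dims()-1 are compared.
      // The break below only leaves this inner loop; the `if (error) break;`
      // at the top of the outer loop stops checking the remaining requests.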
      for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
        if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
          error = true;
          error_message_stream
              << "Mismatched allgather tensor shapes: "
              << "One rank gathered a tensor with dimension " << dim
              << " equal to " << tensor_shape.dim_size(dim)
              << ", but another rank sent a tensor with dimension " << dim
              << " equal to " << request_shape.dim_size(dim) << ".";
          break;
        }
      }
    }
  }

  MPIResponse response;
  response.set_tensor_name(name);

  if (error) {
    std::string error_message = error_message_stream.str();
    response.set_response_type(MPIResponse::ERROR);
    response.set_error_message(error_message);
  } else {
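    // Any request type other than ALLREDUCE is answered as an ALLGATHER, so
    // the ERROR initializer below is always overwritten on this path.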
    auto response_type = MPIResponse::ERROR;
    if (message_type == MPIRequest::ALLREDUCE) {
      response_type = MPIResponse::ALLREDUCE;
    } else {
      response_type = MPIResponse::ALLGATHER;
    }
    response.set_response_type(response_type);
  }

  // Clear all queued up requests for this name. They are now taken care of
  // by the constructed MPI response.
  message_table->erase(it);
  return response;
}