in MPI/Intercommunicator.cs [1351:1548]
public void ReduceScatter<T>(T[] inValues, ReductionOperation<T> op, int[] counts, ref T[] outValues)
{
    // Intercommunicator reduce-scatter.
    // - The inValues contributed by the processes of the remote group are reduced
    //   element-wise with op.
    // - The result is split into consecutive blocks, and local rank i receives a block
    //   of counts[i] elements in outValues (which is (re)allocated to that size).
    // - When T has a native MPI datatype, MPI_Reduce_scatter does all the work.
    //   Otherwise the values are serialized, and the two group roots do the reduction
    //   and the scatter by hand.

    // Make sure the outgoing array is the right size
    if (outValues == null || outValues.Length != counts[Rank])
        outValues = new T[counts[Rank]];

    MPI_Datatype datatype = FastDatatypeCache<T>.datatype;
    if (datatype != Unsafe.MPI_DATATYPE_NULL)
    {
        // Use the low-level MPI reduce-scatter operation directly on pinned buffers.
        using (Operation<T> mpiOp = new Operation<T>(op))
        {
            GCHandle inHandle = GCHandle.Alloc(inValues, GCHandleType.Pinned);
            try
            {
                GCHandle outHandle = GCHandle.Alloc(outValues, GCHandleType.Pinned);
                try
                {
                    int errorCode;
                    unsafe
                    {
                        errorCode = Unsafe.MPI_Reduce_scatter(Marshal.UnsafeAddrOfPinnedArrayElement(inValues, 0),
                                                              Marshal.UnsafeAddrOfPinnedArrayElement(outValues, 0),
                                                              counts, datatype, mpiOp.Op, comm);
                    }
                    if (errorCode != Unsafe.MPI_SUCCESS)
                        throw Environment.TranslateErrorIntoException(errorCode);
                }
                finally
                {
                    // Always unpin, even if the call or error translation throws.
                    outHandle.Free();
                }
            }
            finally
            {
                inHandle.Free();
            }
        }
        return;
    }

    if (Rank != 0)
    {
        // Non-root: hand our contribution to the remote root, then receive our block
        // of the result from it.
        ReduceScatterSendSerialized(inValues, 0);
        outValues = ReduceScatterReceiveSerialized<T>(0);
        return;
    }

    // ---- Local root from here on ----

    // Every contribution holds the sum of the local block sizes.
    int totalCounts = 0;
    for (int i = 0; i < Size; i++)
    {
        checked { totalCounts += counts[i]; }
    }

    // Swap block-size vectors with the remote root. The remote counts say how the
    // result destined for the remote group must be split.
    int[] remoteCounts = new int[RemoteSize];
    GCHandle countsHandle = GCHandle.Alloc(counts, GCHandleType.Pinned);
    try
    {
        GCHandle remoteCountsHandle = GCHandle.Alloc(remoteCounts, GCHandleType.Pinned);
        try
        {
            Unsafe.MPI_Status mpiStatus;
            int errorCode;
            unsafe
            {
                // Both arrays are pinned above, as UnsafeAddrOfPinnedArrayElement requires.
                IntPtr inPtr = Marshal.UnsafeAddrOfPinnedArrayElement(counts, 0);
                IntPtr outPtr = Marshal.UnsafeAddrOfPinnedArrayElement(remoteCounts, 0);
                errorCode = Unsafe.MPI_Sendrecv(inPtr, counts.Length, Unsafe.MPI_INT, 0, collectiveTag,
                    outPtr, remoteCounts.Length, Unsafe.MPI_INT, 0, collectiveTag, shadowComm, out mpiStatus);
            }
            if (errorCode != Unsafe.MPI_SUCCESS)
                throw Environment.TranslateErrorIntoException(errorCode);
        }
        finally
        {
            remoteCountsHandle.Free();
        }
    }
    finally
    {
        countsHandle.Free();
    }

    // Reduce the remote group's contributions in remote-rank order.
    // - The remote root's contribution arrives through a pairwise exchange, in which
    //   it gets ours in return.
    // - That contribution seeds the accumulator, so op is never applied to a
    //   default(T) placeholder. A default(T) seed would be wrong for products,
    //   minima, reference types, and similar cases.
    T[] rootContribution = ReduceScatterExchangeWithRemoteRoot(inValues);
    T[] accValues = new T[totalCounts];
    System.Array.Copy(rootContribution, accValues, totalCounts);
    for (int i = 1; i < RemoteSize; i++)
    {
        T[] contribution = ReduceScatterReceiveSerialized<T>(i);
        for (int j = 0; j < totalCounts; j++)
            accValues[j] = op(accValues[j], contribution[j]);
    }

    // Give the remote root our reduction (it scatters it over our group) and receive
    // its reduction, which we scatter over the remote group.
    T[] remoteAccValues = ReduceScatterExchangeWithRemoteRoot(accValues);

    // Our own block comes straight from the local accumulator; no message needed.
    System.Array.Copy(accValues, 0, outValues, 0, counts[0]);

    // Cut remoteAccValues into consecutive blocks of remoteCounts[i] elements.
    // Block 0 belongs to the remote root, which computes it itself.
    T[][] sendValues = new T[RemoteSize][];
    int currentPos = 0;
    for (int i = 0; i < RemoteSize; i++)
    {
        sendValues[i] = new T[remoteCounts[i]];
        System.Array.Copy(remoteAccValues, currentPos, sendValues[i], 0, remoteCounts[i]);
        checked { currentPos += remoteCounts[i]; }
    }
    for (int i = 1; i < RemoteSize; i++)
        ReduceScatterSendSerialized(sendValues[i], i);
}

/// <summary>
/// Serializes <paramref name="values"/> and sends it with a blocking MPI_Send to rank
/// <paramref name="dest"/> over the shadow communicator (collective tag). Helper for the
/// serialized path of ReduceScatter.
/// </summary>
/// <param name="values">The array to serialize and send.</param>
/// <param name="dest">Destination rank on the shadow communicator.</param>
private void ReduceScatterSendSerialized<TElement>(TElement[] values, int dest)
{
    using (UnmanagedMemoryStream sendStream = new UnmanagedMemoryStream())
    {
        Serialize(sendStream, values);
        int errorCode;
        unsafe
        {
            errorCode = Unsafe.MPI_Send(sendStream.Buffer, Convert.ToInt32(sendStream.Length), Unsafe.MPI_BYTE, dest, collectiveTag, shadowComm);
        }
        if (errorCode != Unsafe.MPI_SUCCESS)
            throw Environment.TranslateErrorIntoException(errorCode);
    }
}

/// <summary>
/// Receives one serialized array from rank <paramref name="source"/> over the shadow
/// communicator (collective tag). The helper first probes, so the receive buffer is sized
/// to the incoming message. Helper for the serialized path of ReduceScatter.
/// </summary>
/// <param name="source">Source rank on the shadow communicator.</param>
/// <returns>The deserialized array.</returns>
private TElement[] ReduceScatterReceiveSerialized<TElement>(int source)
{
    Unsafe.MPI_Status mpiStatus;
    int errorCode = Unsafe.MPI_Probe(source, collectiveTag, shadowComm, out mpiStatus);
    if (errorCode != Unsafe.MPI_SUCCESS)
        throw Environment.TranslateErrorIntoException(errorCode);

    int recvCount;
    errorCode = Unsafe.MPI_Get_count(ref mpiStatus, Unsafe.MPI_BYTE, out recvCount);
    if (errorCode != Unsafe.MPI_SUCCESS)
        throw Environment.TranslateErrorIntoException(errorCode);

    using (UnmanagedMemoryStream recvStream = new UnmanagedMemoryStream(recvCount))
    {
        unsafe
        {
            errorCode = Unsafe.MPI_Recv(recvStream.Buffer, recvCount, Unsafe.MPI_BYTE, source, collectiveTag, shadowComm, out mpiStatus);
        }
        if (errorCode != Unsafe.MPI_SUCCESS)
            throw Environment.TranslateErrorIntoException(errorCode);
        return Deserialize<TElement[]>(recvStream);
    }
}

/// <summary>
/// Exchanges serialized arrays with rank 0 over the shadow communicator (the remote root,
/// when called from the local root). The send is non-blocking and is posted before the
/// receive, so two roots calling this on each other do not deadlock. Helper for the
/// serialized path of ReduceScatter.
/// </summary>
/// <param name="values">The array to send to the remote root.</param>
/// <returns>The array the remote root sent in return.</returns>
private TElement[] ReduceScatterExchangeWithRemoteRoot<TElement>(TElement[] values)
{
    // sendStream owns the buffer handed to MPI_Isend. MPI requires that buffer to stay
    // valid until the request completes, so the stream is disposed only after MPI_Wait.
    using (UnmanagedMemoryStream sendStream = new UnmanagedMemoryStream())
    {
        Serialize(sendStream, values);
        MPI_Request mpiRequest;
        int errorCode;
        unsafe
        {
            errorCode = Unsafe.MPI_Isend(sendStream.Buffer, Convert.ToInt32(sendStream.Length), Unsafe.MPI_BYTE, 0, collectiveTag, shadowComm, out mpiRequest);
        }
        if (errorCode != Unsafe.MPI_SUCCESS)
            throw Environment.TranslateErrorIntoException(errorCode);

        TElement[] received = ReduceScatterReceiveSerialized<TElement>(0);

        Unsafe.MPI_Status mpiStatus;
        errorCode = Unsafe.MPI_Wait(ref mpiRequest, out mpiStatus);
        if (errorCode != Unsafe.MPI_SUCCESS)
            throw Environment.TranslateErrorIntoException(errorCode);
        return received;
    }
}