public void ReduceScatter()

in MPI/Intercommunicator.cs [1351:1548]


        public void ReduceScatter<T>(T[] inValues, ReductionOperation<T> op, int[] counts, ref T[] outValues)
        {
            // Intercommunicator reduce-scatter. The element-wise reduction (via op) of the
            // remote group's input vectors is scattered over this group: local rank i receives
            // counts[i] elements in outValues. The remote group likewise receives pieces of the
            // reduction of this group's inputs.
            //
            // If T has a native MPI datatype, MPI_Reduce_scatter does the work. Otherwise the
            // roots of both groups (local rank 0) gather, reduce and scatter serialized data
            // over shadowComm by hand.
            // NOTE(review): assumes every input vector holds sum(counts) elements and that both
            // groups agree on that total, as the MPI standard requires.
            if (counts == null)
                throw new ArgumentNullException("counts");

            // Make sure the outgoing array is the right size
            if (outValues == null || outValues.Length != counts[Rank])
                outValues = new T[counts[Rank]];

            MPI_Datatype datatype = FastDatatypeCache<T>.datatype;
            if (datatype == Unsafe.MPI_DATATYPE_NULL)
            {
                if (Rank == 0)
                {
                    // Number of elements in every input vector.
                    int totalCounts = 0;
                    checked
                    {
                        for (int i = 0; i < Size; i++)
                            totalCounts += counts[i];
                    }

                    // Swap scatter layouts with the remote root. Pin both arrays for the
                    // duration of the native call so the GC cannot move them.
                    int[] remoteCounts = new int[RemoteSize];
                    Unsafe.MPI_Status mpiStatus;
                    int errorCode;
                    unsafe
                    {
                        fixed (int* countsPtr = counts, remoteCountsPtr = remoteCounts)
                        {
                            errorCode = Unsafe.MPI_Sendrecv(new IntPtr(countsPtr), counts.Length, Unsafe.MPI_INT, 0, collectiveTag,
                                                            new IntPtr(remoteCountsPtr), remoteCounts.Length, Unsafe.MPI_INT, 0, collectiveTag,
                                                            shadowComm, out mpiStatus);
                        }
                    }
                    if (errorCode != Unsafe.MPI_SUCCESS)
                        throw Environment.TranslateErrorIntoException(errorCode);

                    // Hand our input to the remote root, which reduces our group's data, and
                    // take the remote root's input as the first term of the remote reduction.
                    // Seeding with real data (not default(T)) keeps non-additive ops such as
                    // product, min and max correct.
                    T[] remoteRootValues = ReduceScatterExchangeWithRemoteRoot(inValues);
                    T[] accValues = new T[totalCounts];
                    System.Array.Copy(remoteRootValues, accValues, totalCounts);

                    // Fold in every other remote rank as it arrives; nothing is retained.
                    for (int source = 1; source < RemoteSize; source++)
                    {
                        T[] contribution = ReduceScatterReceiveFrom<T[]>(source);
                        for (int j = 0; j < totalCounts; j++)
                            accValues[j] = op(accValues[j], contribution[j]);
                    }

                    // Swap reductions with the remote root. accValues (the remote group's
                    // reduction) is scattered to our group by the remote root; remoteAccValues
                    // (our group's reduction) is scattered to the remote group by us.
                    T[] remoteAccValues = ReduceScatterExchangeWithRemoteRoot(accValues);

                    // The remote root skips sending our own block, so take it directly.
                    System.Array.Copy(accValues, 0, outValues, 0, counts[0]);

                    // Scatter blocks 1..RemoteSize-1 of our reduction to the remote ranks.
                    // Block 0 is taken locally by the remote root, mirroring the copy above.
                    checked
                    {
                        int offset = remoteCounts[0];
                        for (int dest = 1; dest < RemoteSize; dest++)
                        {
                            T[] block = new T[remoteCounts[dest]];
                            System.Array.Copy(remoteAccValues, offset, block, 0, remoteCounts[dest]);
                            ReduceScatterSendTo(block, dest);
                            offset += remoteCounts[dest];
                        }
                    }
                }
                else
                {
                    // Contribute our input to the remote root's reduction, then receive our
                    // block of the remote group's reduction from that same root.
                    ReduceScatterSendTo(inValues, 0);
                    outValues = ReduceScatterReceiveFrom<T[]>(0);
                }
            }
            else
            {
                // Use the low-level MPI reduce-scatter operation
                using (Operation<T> mpiOp = new Operation<T>(op))
                {
                    int errorCode;
                    GCHandle inHandle = GCHandle.Alloc(inValues, GCHandleType.Pinned);
                    try
                    {
                        GCHandle outHandle = GCHandle.Alloc(outValues, GCHandleType.Pinned);
                        try
                        {
                            unsafe
                            {
                                errorCode = Unsafe.MPI_Reduce_scatter(Marshal.UnsafeAddrOfPinnedArrayElement(inValues, 0),
                                                                      Marshal.UnsafeAddrOfPinnedArrayElement(outValues, 0),
                                                                      counts, datatype, mpiOp.Op, comm);
                            }
                        }
                        finally
                        {
                            outHandle.Free();
                        }
                    }
                    finally
                    {
                        inHandle.Free();
                    }

                    if (errorCode != Unsafe.MPI_SUCCESS)
                        throw Environment.TranslateErrorIntoException(errorCode);
                }
            }
        }

        /// <summary>
        /// Serializes <paramref name="value"/> and sends it with a blocking MPI_Send to rank
        /// <paramref name="dest"/> on <c>shadowComm</c>, using <c>collectiveTag</c>.
        /// Helper for the serialized path of <c>ReduceScatter</c>.
        /// </summary>
        private void ReduceScatterSendTo<TValue>(TValue value, int dest)
        {
            using (UnmanagedMemoryStream sendStream = new UnmanagedMemoryStream())
            {
                Serialize(sendStream, value);
                int errorCode;
                unsafe
                {
                    errorCode = Unsafe.MPI_Send(sendStream.Buffer, Convert.ToInt32(sendStream.Length), Unsafe.MPI_BYTE, dest, collectiveTag, shadowComm);
                }
                if (errorCode != Unsafe.MPI_SUCCESS)
                    throw Environment.TranslateErrorIntoException(errorCode);
            }
        }

        /// <summary>
        /// Probes for the next <c>collectiveTag</c> message from rank <paramref name="source"/>
        /// on <c>shadowComm</c>. Receives it into a buffer of the probed size and deserializes it.
        /// Helper for the serialized path of <c>ReduceScatter</c>.
        /// </summary>
        private TValue ReduceScatterReceiveFrom<TValue>(int source)
        {
            Unsafe.MPI_Status status;
            int errorCode = Unsafe.MPI_Probe(source, collectiveTag, shadowComm, out status);
            if (errorCode != Unsafe.MPI_SUCCESS)
                throw Environment.TranslateErrorIntoException(errorCode);

            int recvCount;
            errorCode = Unsafe.MPI_Get_count(ref status, Unsafe.MPI_BYTE, out recvCount);
            if (errorCode != Unsafe.MPI_SUCCESS)
                throw Environment.TranslateErrorIntoException(errorCode);

            using (UnmanagedMemoryStream recvStream = new UnmanagedMemoryStream(recvCount))
            {
                unsafe
                {
                    errorCode = Unsafe.MPI_Recv(recvStream.Buffer, recvCount, Unsafe.MPI_BYTE, source, collectiveTag, shadowComm, out status);
                }
                if (errorCode != Unsafe.MPI_SUCCESS)
                    throw Environment.TranslateErrorIntoException(errorCode);
                return Deserialize<TValue>(recvStream);
            }
        }

        /// <summary>
        /// Sends <paramref name="value"/> to rank 0 of <c>shadowComm</c> (the remote root)
        /// and returns the value the remote root sends back in the same step.
        /// Helper for the serialized path of <c>ReduceScatter</c>.
        /// </summary>
        private TValue ReduceScatterExchangeWithRemoteRoot<TValue>(TValue value)
        {
            using (UnmanagedMemoryStream sendStream = new UnmanagedMemoryStream())
            {
                Serialize(sendStream, value);
                MPI_Request request;
                int errorCode;
                unsafe
                {
                    errorCode = Unsafe.MPI_Isend(sendStream.Buffer, Convert.ToInt32(sendStream.Length), Unsafe.MPI_BYTE, 0, collectiveTag, shadowComm, out request);
                }
                if (errorCode != Unsafe.MPI_SUCCESS)
                    throw Environment.TranslateErrorIntoException(errorCode);

                TValue received;
                try
                {
                    // Both roots run this exchange at the same time. A non-blocking send
                    // followed by a receive therefore cannot deadlock.
                    received = ReduceScatterReceiveFrom<TValue>(0);
                }
                finally
                {
                    // sendStream owns the Isend buffer. It must not be disposed (by the
                    // enclosing using) until the send has completed.
                    Unsafe.MPI_Status status;
                    errorCode = Unsafe.MPI_Wait(ref request, out status);
                }
                if (errorCode != Unsafe.MPI_SUCCESS)
                    throw Environment.TranslateErrorIntoException(errorCode);
                return received;
            }
        }