include/nccl_ofi_msgbuff.h (47 lines of code) (raw):

/*
 * Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All rights reserved.
 */

#ifndef NCCL_OFI_MSGBUFF_H_
#define NCCL_OFI_MSGBUFF_H_

#include <pthread.h>
/* Needed for the `bool` return type of nccl_ofi_msgbuff_destroy(); before C23
 * `bool` is not a keyword, so without this the header is not self-contained. */
#include <stdbool.h>
#include <stdint.h>

/**
 * A "modified circular buffer" used to track in-flight (or INPROGRESS) messages.
 *
 * Messages are identified by a wrapping sequence number, whose bit width is
 * chosen during initialization. The buffer maintains two pointers:
 * - msg_next: one past the inserted message with the highest sequence number
 * - msg_last_incomplete: the not-yet-completed message with the lowest
 *   sequence number
 *
 * The space of all sequence numbers is divided into 3 contiguous, moving
 * sections:
 *
 * 1. In-flight messages. The maximum size N of this section is chosen during
 *    initialization. Only this section has elements actually stored in the
 *    backing buffer, so N (which is also the size of the backing buffer) is
 *    the maximum number of in-flight messages allowed. N should be less than
 *    half of the overall range of sequence numbers, to leave room for the
 *    other two sections. A modulus of the sequence number is used to index
 *    the backing buffer.
 * 2. Completed messages. This section always has size N and immediately
 *    precedes section 1: all N sequence numbers preceding section 1 (with
 *    possible wraparound) are implicitly considered to belong to completed
 *    messages. Every time the pending message with the lowest sequence number
 *    is completed, msg_last_incomplete is advanced (possibly more than once,
 *    if the following sequence numbers also belong to messages that were
 *    completed out of order). This moves the bottom of section 1 forward, and
 *    implicitly the bottom of section 2 as well.
 * 3. Not-started messages: all other sequence numbers.
 *
 * The buffer for in-flight messages stores void* elements: the user of the
 * buffer is responsible for managing the memory of buffer elements.
 */

/* Enumeration to keep track of different msg statuses. */
typedef enum {
	/** The message has been marked completed **/
	NCCL_OFI_MSGBUFF_COMPLETED,
	/** The message has been added to the buffer but not marked complete **/
	NCCL_OFI_MSGBUFF_INPROGRESS,
	/** The message has not yet been added to the buffer **/
	NCCL_OFI_MSGBUFF_NOTSTARTED,
	/** The index is not in the range of completed or not-started messages **/
	NCCL_OFI_MSGBUFF_UNAVAILABLE,
} nccl_ofi_msgbuff_status_t;

/* Result codes returned by the msgbuff operations. */
typedef enum {
	/** Operation completed successfully **/
	NCCL_OFI_MSGBUFF_SUCCESS,
	/** The provided index was invalid; see msg_idx_status output **/
	NCCL_OFI_MSGBUFF_INVALID_IDX,
	/** Other error **/
	NCCL_OFI_MSGBUFF_ERROR,
} nccl_ofi_msgbuff_result_t;

/* Type of element stored in the msg buffer. This is used to distinguish
 * between reqs and rx buffers (when we don't have a req) stored in the
 * message buffer. */
typedef enum {
	/* Request */
	NCCL_OFI_MSGBUFF_REQ,
	/* Rx buffer */
	NCCL_OFI_MSGBUFF_BUFF
} nccl_ofi_msgbuff_elemtype_t;

/* Internal buffer storage type, used to keep the status of elements
 * currently stored in the buffer. */
typedef struct {
	/* Status of message: COMPLETED, INPROGRESS, or NOTSTARTED */
	nccl_ofi_msgbuff_status_t stat;
	/* Type of element */
	nccl_ofi_msgbuff_elemtype_t type;
	/* User element; its memory is managed by the user of the buffer */
	void *elem;
} nccl_ofi_msgbuff_elem_t;

typedef struct {
	/* Element storage buffer. Allocated in nccl_ofi_msgbuff_init() */
	nccl_ofi_msgbuff_elem_t *buff;
	/* Max number of INPROGRESS elements. These are the only ones backed by
	 * the storage buffer, so this is also the size of the storage buffer */
	uint16_t max_inprogress;
	/* Size of the range of all possible sequence numbers, which depends on
	 * how many bits are used for them.
	 * NOTE(review): as a uint16_t this cannot hold 65536, so a 16-bit
	 * sequence number range presumably is not supported -- confirm against
	 * the implementation of nccl_ofi_msgbuff_init(). */
	uint16_t field_size;
	/* Bit mask for the sequence numbers */
	uint16_t field_mask;
	/* Points to the not-finished message with the lowest sequence number */
	uint16_t msg_last_incomplete;
	/* Points to the message after the inserted message with the highest
	 * sequence number */
	uint16_t msg_next;
	/* Mutex for this msg buffer -- locks all non-init operations */
	pthread_mutex_t lock;
} nccl_ofi_msgbuff_t;

/**
 * Allocates and initializes a new message buffer.
 *
 * @param max_inprogress  max number of INPROGRESS elements, which are backed
 *                        by the storage buffer
 * @param bit_width       bit width of the sequence numbers, which provides the
 *                        range of elements tracked by this msgbuff
 *
 * @return a new msgbuff, or NULL if initialization failed
 */
nccl_ofi_msgbuff_t *nccl_ofi_msgbuff_init(uint16_t max_inprogress, uint16_t bit_width);

/**
 * Destroy a message buffer (free memory used by the buffer).
 *
 * @param msgbuff  message buffer to destroy
 *
 * @return true if success, false if failed
 */
bool nccl_ofi_msgbuff_destroy(nccl_ofi_msgbuff_t *msgbuff);

/**
 * Insert a new message element.
 *
 * @param msgbuff         message buffer
 * @param msg_index       sequence number of the message
 * @param elem            pointer to store at msg_index
 * @param type            type of element
 * @param msg_idx_status  output: message status, if return value is INVALID_IDX
 *
 * @return
 *   NCCL_OFI_MSGBUFF_SUCCESS, success
 *   NCCL_OFI_MSGBUFF_INVALID_IDX, invalid index. See msg_idx_status.
 *   NCCL_OFI_MSGBUFF_ERROR, other error
 */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff,
		uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type,
		nccl_ofi_msgbuff_status_t *msg_idx_status);

/**
 * Replace an existing message element.
 *
 * @param msgbuff         message buffer
 * @param msg_index       sequence number of the message
 * @param elem            pointer to store at msg_index
 * @param type            type of element
 * @param msg_idx_status  output: message status, if return value is INVALID_IDX
 *
 * @return
 *   NCCL_OFI_MSGBUFF_SUCCESS, success
 *   NCCL_OFI_MSGBUFF_INVALID_IDX, invalid index. See msg_idx_status.
 *   NCCL_OFI_MSGBUFF_ERROR, other error
 */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff,
		uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type,
		nccl_ofi_msgbuff_status_t *msg_idx_status);

/**
 * Retrieve the message with the given index.
 *
 * @param msgbuff         message buffer
 * @param msg_index       sequence number of the message
 * @param elem            output: pointer to element at msg_index
 * @param type            output: type of element
 * @param msg_idx_status  output: message status, if return value is INVALID_IDX
 *
 * @return
 *   NCCL_OFI_MSGBUFF_SUCCESS, success
 *   NCCL_OFI_MSGBUFF_INVALID_IDX, invalid index. See msg_idx_status.
 *   NCCL_OFI_MSGBUFF_ERROR, other error
 */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff,
		uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type,
		nccl_ofi_msgbuff_status_t *msg_idx_status);

/**
 * Mark the message with the given index as complete.
 *
 * @param msgbuff         message buffer
 * @param msg_index       sequence number of the message
 * @param msg_idx_status  output: message status, if return value is INVALID_IDX
 *
 * @return
 *   NCCL_OFI_MSGBUFF_SUCCESS, success
 *   NCCL_OFI_MSGBUFF_INVALID_IDX, invalid index. See msg_idx_status.
 *   NCCL_OFI_MSGBUFF_ERROR, other error
 */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff,
		uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status);

#endif // End NCCL_OFI_MSGBUFF_H_