include/nccl_ofi_mr.h (154 lines of code) (raw):

/* * Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All rights reserved. */ #ifndef NCCL_OFI_MR_H_ #define NCCL_OFI_MR_H_ #include "config.h" #include <assert.h> #include <pthread.h> #include <stddef.h> #include <stdint.h> #include <sys/uio.h> #include <rdma/fi_domain.h> #include "nccl_ofi_math.h" #include "nccl_ofi_log.h" #define NCCL_OFI_CACHE_PAGE_SIZE (4096ul) enum nccl_ofi_mr_ckey_type { NCCL_OFI_MR_CKEY_INVALID = 0, NCCL_OFI_MR_CKEY_IOVEC, #if HAVE_DECL_FI_MR_DMABUF NCCL_OFI_MR_CKEY_DMABUF, #endif }; typedef enum nccl_ofi_mr_ckey_type nccl_ofi_mr_ckey_type_t; struct nccl_ofi_mr_ckey { union { struct iovec iovec; #if HAVE_DECL_FI_MR_DMABUF struct fi_mr_dmabuf fi_mr_dmabuf; #endif }; enum nccl_ofi_mr_ckey_type type; }; typedef struct nccl_ofi_mr_ckey nccl_ofi_mr_ckey_t; typedef struct nccl_ofi_mr_ckey const *const nccl_ofi_mr_ckey_ref; static_assert(offsetof(struct nccl_ofi_mr_ckey, iovec) == 0, "Cache keys must be safe to cast to 'struct iovec'"); #if HAVE_DECL_FI_MR_DMABUF static_assert(offsetof(struct nccl_ofi_mr_ckey, fi_mr_dmabuf) == 0, "Cache keys must be safe to cast to 'struct fi_mr_dmabuf'"); #endif /* Alignement of MR cache and key creation */ extern size_t mr_cache_alignment; static inline const char *nccl_ofi_mr_ckey_type_str(nccl_ofi_mr_ckey_ref ckey) { switch (ckey->type) { case NCCL_OFI_MR_CKEY_IOVEC: return "iovec"; #if HAVE_DECL_FI_MR_DMABUF case NCCL_OFI_MR_CKEY_DMABUF: return "dmabuf"; #endif case NCCL_OFI_MR_CKEY_INVALID: default: __builtin_unreachable(); assert(false); return ""; } } static inline uintptr_t nccl_ofi_mr_ckey_baseaddr(nccl_ofi_mr_ckey_ref ckey) { switch (ckey->type) { case NCCL_OFI_MR_CKEY_IOVEC: return (uintptr_t)ckey->iovec.iov_base; #if HAVE_DECL_FI_MR_DMABUF case NCCL_OFI_MR_CKEY_DMABUF: return (uintptr_t)ckey->fi_mr_dmabuf.base_addr + ckey->fi_mr_dmabuf.offset; #endif case NCCL_OFI_MR_CKEY_INVALID: default: __builtin_unreachable(); assert(false); return 0; } } static inline uintptr_t nccl_ofi_mr_ckey_len(nccl_ofi_mr_ckey_ref ckey) { switch (ckey->type) { case NCCL_OFI_MR_CKEY_IOVEC: return ckey->iovec.iov_len; #if HAVE_DECL_FI_MR_DMABUF case NCCL_OFI_MR_CKEY_DMABUF: return ckey->fi_mr_dmabuf.len; #endif case NCCL_OFI_MR_CKEY_INVALID: default: __builtin_unreachable(); assert(false); return 0; } } static inline void nccl_ofi_mr_ckey_round(size_t *len, void **base_addr, const char *type) { uintptr_t page_base = NCCL_OFI_ROUND_DOWN((uintptr_t)*base_addr, mr_cache_alignment); size_t aligned_size = NCCL_OFI_ROUND_UP(((uintptr_t)*base_addr + *len), mr_cache_alignment) - page_base; NCCL_OFI_TRACE_WHEN(((uintptr_t)*base_addr != page_base || aligned_size != *len), NCCL_NET, "Going to register mr %s %p size %ld as %p size %ld", type, *base_addr, *len, (void *)page_base, aligned_size); *base_addr = (void *)page_base; *len = aligned_size; } #if HAVE_DECL_FI_MR_DMABUF static inline nccl_ofi_mr_ckey_t nccl_ofi_mr_ckey_mk_dmabuf(int fd, uint64_t offset, size_t len, void *base_addr) { nccl_ofi_mr_ckey_t cache_key = {}; cache_key.fi_mr_dmabuf.fd = fd; cache_key.fi_mr_dmabuf.offset = offset; cache_key.fi_mr_dmabuf.len = len; cache_key.fi_mr_dmabuf.base_addr = base_addr; cache_key.type = NCCL_OFI_MR_CKEY_DMABUF; return cache_key; } #endif static inline nccl_ofi_mr_ckey_t nccl_ofi_mr_ckey_mk_vec(void *iov_base, size_t iov_len) { nccl_ofi_mr_ckey_round(&iov_len, &iov_base, "iovec"); nccl_ofi_mr_ckey_t cache_key = {}; cache_key.iovec.iov_base = iov_base; cache_key.iovec.iov_len = iov_len; cache_key.type = NCCL_OFI_MR_CKEY_IOVEC; return cache_key; } static inline void nccl_ofi_mr_ckey_fill_mr_attrs(nccl_ofi_mr_ckey_ref ckey, struct fi_mr_attr *attrs, uint64_t *flags) { assert(ckey->type != NCCL_OFI_MR_CKEY_INVALID); *flags = 0; #if HAVE_DECL_FI_MR_DMABUF if (ckey->type == NCCL_OFI_MR_CKEY_DMABUF) { *flags |= FI_MR_DMABUF; // note: because ckey's first member is layout-compatible with // fi_mr_attr's first member, both sides of this branch are the // same as // `memcpy(attrs, ckey, max(sizeof(struct iovec), // sizeof(struct fi_mr_dmabuf)))' attrs->dmabuf = (const struct fi_mr_dmabuf *)ckey; } else { // see comment above attrs->mr_iov = (const struct iovec *)ckey; } #else attrs->mr_iov = (const struct iovec *)ckey; #endif attrs->iov_count = 1; } /** * A memory registration cache entry */ typedef struct nccl_ofi_reg_entry { uintptr_t addr; size_t pages; int refcnt; void *handle; } nccl_ofi_reg_entry_t; /** * Device-specific memory registration cache. */ typedef struct nccl_ofi_mr_cache { nccl_ofi_reg_entry_t **slots; size_t system_page_size; size_t size; size_t used; uint32_t hit_count; uint32_t miss_count; pthread_mutex_t lock; } nccl_ofi_mr_cache_t; /** * Create a new mr cache. Both then initial number of entries and the system * page size must be greater than zero. * @return a new mr cache, or NULL if an allocation error occurred */ nccl_ofi_mr_cache_t *nccl_ofi_mr_cache_init(size_t init_num_entries, size_t mr_cache_page_size); /** * Finalize mr cache */ void nccl_ofi_mr_cache_finalize(nccl_ofi_mr_cache_t *cache); /** * Lookup a cache entry matching the given address and size * Input addr and size are rounded up to enclosing page boundaries. * If entry is found, refcnt is increased * @return mr handle if found, or NULL if not found */ void *nccl_ofi_mr_cache_lookup_entry(nccl_ofi_mr_cache_t *cache, nccl_ofi_mr_ckey_ref ckey); /** * Insert a new cache entry with the given address and size * Input addr and size are rounded up to enclosing page boundaries. * @return 0, on success * -ENOMEM, on allocation failure * -EEXIST, if matching entry already exists in cache */ int nccl_ofi_mr_cache_insert_entry(nccl_ofi_mr_cache_t *cache, nccl_ofi_mr_ckey_ref ckey, void *handle); /** * Decrement refcnt of entry with given handle. If refcnt was reduced to 0, * delete entry from cache. Return value indicates whether entry was deleted * from cache (in which case, caller should deregister the handle). * * @return 0, on success, and reg was not deleted (refcnt not zero) * 1, on success, and reg was deleted (refcnt was zero) * -ENOENT, if no matching entry was found */ int nccl_ofi_mr_cache_del_entry(nccl_ofi_mr_cache_t *cache, void *handle); #endif // End NCCL_OFI_MR_H_