src/nccl_ofi_interface_nvidia.cpp (358 lines of code) (raw):

/*
 * Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All rights reserved.
 */

#include "config.h"

#include "nccl_ofi.h"
#include "nccl_ofi_api.h"


/**
 * Fill NCCL's v9 network properties for device @p dev_id.
 *
 * This is the only place where the plugin-internal nccl_ofi_properties_t is
 * translated into an NCCL property struct. Every older getProperties variant
 * below derives its answer from this one by copying the subset of fields its
 * struct version carries, so a change to the mapping (e.g. ptrSupport bits)
 * only has to be made here.
 *
 * @param dev_id  Device index, passed through to nccl_net_ofi_get_properties().
 * @param props   Output; written only when ncclSuccess is returned.
 * @return ncclSuccess, or the error returned by nccl_net_ofi_get_properties().
 */
static ncclResult_t getProperties_v9(int dev_id, ncclNetProperties_v9_t* props)
{
	nccl_ofi_properties_t ofi_properties;
	ncclResult_t ret = nccl_net_ofi_get_properties(dev_id, &ofi_properties);
	if (ret != ncclSuccess) {
		return ret;
	}

	props->name = ofi_properties.name;
	props->pciPath = ofi_properties.pci_path;
	props->guid = ofi_properties.guid;

	/* Host memory is always supported; CUDA and DMA-BUF depend on the device. */
	props->ptrSupport = NCCL_PTR_HOST;
	if (ofi_properties.hmem_support) {
		props->ptrSupport |= NCCL_PTR_CUDA;
	}
	if (ofi_properties.dmabuf_support) {
		props->ptrSupport |= NCCL_PTR_DMABUF;
	}

	/**
	 * When the net-plugin returns regIsGlobal=1 to NCCL (as part of the
	 * net-plugin getProperties() API), it signals to NCCL that registered
	 * MRs are global, in the sense that they can be used by all
	 * communicators. In addition, it signals to NCCL that the net-plugin
	 * has a fast MR cache, such that calling regMr() on the same buffer
	 * (address and size) will quickly return a previously globally
	 * registered MR for that buffer.
	 *
	 * When a user registers a buffer with NCCL using the
	 * ncclCommRegister() API, and the net-plugin supports regIsGlobal=1,
	 * NCCL registers the buffer globally once (on each net device) with
	 * the regMr() API. When the net proxy-thread starts to execute a
	 * communication task on a previously registered user buffer, it calls
	 * the net-plugin regMr() to quickly fetch the previously globally
	 * registered MR from the plugin-managed MR cache.
	 */
	props->regIsGlobal = ofi_properties.regIsGlobal;

	props->speed = ofi_properties.port_speed;
	props->port = ofi_properties.port_number;
	props->latency = ofi_properties.latency;
	props->maxComms = ofi_properties.max_communicators;
	props->maxRecvs = ofi_properties.max_group_receives;

	/* Only host-driven networking is offered; no device-side (kernel) offload. */
	props->netDeviceType = NCCL_NET_DEVICE_HOST;
	props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION;

	/* Each device is reported as a one-member virtual device containing only itself. */
	props->vProps.ndevs = 1;
	props->vProps.devs[0] = dev_id;

	props->maxP2pBytes = ofi_properties.max_p2p_bytes;
	props->maxCollBytes = ofi_properties.max_coll_bytes;

	return ncclSuccess;
}

/**
 * Fill NCCL's v8 network properties by narrowing the v9 answer.
 *
 * The v8 struct lacks vProps, maxP2pBytes and maxCollBytes; regIsGlobal is
 * carried over with the semantics described in getProperties_v9().
 */
static ncclResult_t getProperties_v8(int dev_id, ncclNetProperties_v8_t* props)
{
	ncclNetProperties_v9_t props_v9;
	ncclResult_t ret = getProperties_v9(dev_id, &props_v9);
	if (ret != ncclSuccess) {
		return ret;
	}

	props->name = props_v9.name;
	props->pciPath = props_v9.pciPath;
	props->guid = props_v9.guid;
	props->ptrSupport = props_v9.ptrSupport;
	props->regIsGlobal = props_v9.regIsGlobal;
	props->speed = props_v9.speed;
	props->port = props_v9.port;
	props->latency = props_v9.latency;
	props->maxComms = props_v9.maxComms;
	props->maxRecvs = props_v9.maxRecvs;
	props->netDeviceType = props_v9.netDeviceType;
	props->netDeviceVersion = props_v9.netDeviceVersion;

	return ncclSuccess;
}

/**
 * Fill NCCL's v7 network properties by narrowing the v8 answer
 * (the v7 struct has no regIsGlobal field).
 */
static ncclResult_t getProperties_v7(int dev_id, ncclNetProperties_v7_t *props)
{
	ncclNetProperties_v8_t props_v8;
	ncclResult_t ret = getProperties_v8(dev_id, &props_v8);
	if (ret != ncclSuccess) {
		return ret;
	}

	props->name = props_v8.name;
	props->pciPath = props_v8.pciPath;
	props->guid = props_v8.guid;
	props->ptrSupport = props_v8.ptrSupport;
	props->speed = props_v8.speed;
	props->port = props_v8.port;
	props->latency = props_v8.latency;
	props->maxComms = props_v8.maxComms;
	props->maxRecvs = props_v8.maxRecvs;
	props->netDeviceType = props_v8.netDeviceType;
	props->netDeviceVersion = props_v8.netDeviceVersion;

	return ncclSuccess;
}

/**
 * Fill the v6 property struct (used by both the v5 and v6 plugin tables)
 * by narrowing the v7 answer; the v6 struct has no net-device fields.
 */
static ncclResult_t getProperties_v5(int dev_id, ncclNetProperties_v6_t *props)
{
	ncclNetProperties_v7_t props_v7;
	ncclResult_t ret = getProperties_v7(dev_id, &props_v7);
	if (ret != ncclSuccess) {
		return ret;
	}

	props->name = props_v7.name;
	props->pciPath = props_v7.pciPath;
	props->guid = props_v7.guid;
	props->ptrSupport = props_v7.ptrSupport;
	props->speed = props_v7.speed;
	props->port = props_v7.port;
	props->latency = props_v7.latency;
	props->maxComms = props_v7.maxComms;
	props->maxRecvs = props_v7.maxRecvs;

	return ncclSuccess;
}

/**
 * Fill the v4 property struct (used by both the v3 and v4 plugin tables)
 * by narrowing the v6 answer; latency and maxRecvs are not carried.
 */
static ncclResult_t getProperties_v3(int dev_id, ncclNetProperties_v4_t* props)
{
	ncclNetProperties_v6_t props_v6;
	ncclResult_t ret = getProperties_v5(dev_id, &props_v6);
	if (ret != ncclSuccess) {
		return ret;
	}

	props->name = props_v6.name;
	props->pciPath = props_v6.pciPath;
	props->guid = props_v6.guid;
	props->ptrSupport = props_v6.ptrSupport;
	props->speed = props_v6.speed;
	props->port = props_v6.port;
	props->maxComms = props_v6.maxComms;

	return ncclSuccess;
}

/**
 * v2 API: report the PCI path of device @p dev_id through @p path.
 *
 * NOTE(review): some ncclNet_v2 headers document that NCCL calls free() on
 * the returned path; confirm whether a heap copy is required here.
 */
static ncclResult_t pciPath_v2(int dev_id, char** path)
{
	ncclNetProperties_v6_t props_v6;
	ncclResult_t ret = getProperties_v5(dev_id, &props_v6);
	if (ret != ncclSuccess) {
		return ret;
	}

	/* Must be the PCI path, not the device name. */
	*path = props_v6.pciPath;

	return ncclSuccess;
}

/**
 * v2 API: report the NCCL_PTR_* bitmask of memory types usable by
 * device @p dev_id through @p supportedTypes.
 */
static ncclResult_t ptrSupport_v2(int dev_id, int *supportedTypes)
{
	ncclNetProperties_v6_t props_v6;
	ncclResult_t ret = getProperties_v5(dev_id, &props_v6);
	if (ret != ncclSuccess) {
		return ret;
	}

	*supportedTypes = props_v6.ptrSupport;

	return ncclSuccess;
}

// NVIDIA introduced the ability to have part of the communication driven by a
// CUDA kernel, which requires a version-specific device handle to be passed
// through the connect/accept APIs. We don't support that interface, so we
// never need to look at the device-handle argument. Rather than pollute the
// API interface, just declare these wrappers in the NVIDIA interface.
static ncclResult_t nccl_net_ofi_connect_v7(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v7_t** sendDevComm) { return nccl_net_ofi_connect_v5(dev, handle, sendComm); } static ncclResult_t nccl_net_ofi_connect_v8(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm) { return nccl_net_ofi_connect_v5(dev, handle, sendComm); } static ncclResult_t nccl_net_ofi_connect_v9(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v9_t** sendDevComm) { return nccl_net_ofi_connect_v5(dev, handle, sendComm); } static ncclResult_t nccl_net_ofi_accept_v7(void* listenComm, void** recvComm, ncclNetDeviceHandle_v7_t** recvDevComm) { return nccl_net_ofi_accept_v5(listenComm, recvComm); } static ncclResult_t nccl_net_ofi_accept_v8(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm) { return nccl_net_ofi_accept_v5(listenComm, recvComm); } static ncclResult_t nccl_net_ofi_accept_v9(void* listenComm, void** recvComm, ncclNetDeviceHandle_v9_t** recvDevComm) { return nccl_net_ofi_accept_v5(listenComm, recvComm); } extern "C" { NCCL_OFI_EXPORT_SYMBOL ncclNet_v2_t ncclNetPlugin_v2 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .pciPath = pciPath_v2, .ptrSupport = ptrSupport_v2, .listen = nccl_net_ofi_listen_v2, .connect = nccl_net_ofi_connect_v2, .accept = nccl_net_ofi_accept_v2, .regMr = nccl_net_ofi_regMr_v2, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v2, .irecv = nccl_net_ofi_irecv_v2, .flush = nccl_net_ofi_flush_v2, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v3_t ncclNetPlugin_v3 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v3, .listen = nccl_net_ofi_listen_v2, .connect = nccl_net_ofi_connect_v2, .accept = nccl_net_ofi_accept_v2, .regMr = 
nccl_net_ofi_regMr_v2, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v2, .irecv = nccl_net_ofi_irecv_v2, .flush = nccl_net_ofi_flush_v2, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v4_t ncclNetPlugin_v4 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v3, .listen = nccl_net_ofi_listen_v2, .connect = nccl_net_ofi_connect_v2, .accept = nccl_net_ofi_accept_v2, .regMr = nccl_net_ofi_regMr_v2, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v2, .irecv = nccl_net_ofi_irecv_v2, .iflush = nccl_net_ofi_iflush_v4, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v5_t ncclNetPlugin_v5 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v5, .listen = nccl_net_ofi_listen_v5, .connect = nccl_net_ofi_connect_v5, .accept = nccl_net_ofi_accept_v5, .regMr = nccl_net_ofi_regMr_v2, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v5, .irecv = nccl_net_ofi_irecv_v5, .iflush = nccl_net_ofi_iflush_v5, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v6_t ncclNetPlugin_v6 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v5, .listen = nccl_net_ofi_listen_v5, .connect = nccl_net_ofi_connect_v5, .accept = nccl_net_ofi_accept_v5, .regMr = nccl_net_ofi_regMr_v2, .regMrDmaBuf = nccl_net_ofi_regMrDmaBuf_v6, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v5, .irecv = nccl_net_ofi_irecv_v5, .iflush = 
nccl_net_ofi_iflush_v5, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v7_t ncclNetPlugin_v7 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v7, .listen = nccl_net_ofi_listen_v5, .connect = nccl_net_ofi_connect_v7, .accept = nccl_net_ofi_accept_v7, .regMr = nccl_net_ofi_regMr_v2, .regMrDmaBuf = nccl_net_ofi_regMrDmaBuf_v6, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v5, .irecv = nccl_net_ofi_irecv_v5, .iflush = nccl_net_ofi_iflush_v5, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, .getDeviceMr = NULL, .irecvConsumed = NULL, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v8_t ncclNetPlugin_v8 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v8, .listen = nccl_net_ofi_listen_v5, .connect = nccl_net_ofi_connect_v8, .accept = nccl_net_ofi_accept_v8, .regMr = nccl_net_ofi_regMr_v8, .regMrDmaBuf = nccl_net_ofi_regMrDmaBuf_v6, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v5, .irecv = nccl_net_ofi_irecv_v5, .iflush = nccl_net_ofi_iflush_v5, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, .getDeviceMr = NULL, .irecvConsumed = NULL, }; NCCL_OFI_EXPORT_SYMBOL ncclNet_v9_t ncclNetPlugin_v9 = { .name = "Libfabric", .init = nccl_net_ofi_init_v2, .devices = nccl_net_ofi_devices_v2, .getProperties = getProperties_v9, .listen = nccl_net_ofi_listen_v5, .connect = nccl_net_ofi_connect_v9, .accept = nccl_net_ofi_accept_v9, .regMr = nccl_net_ofi_regMr_v8, .regMrDmaBuf = nccl_net_ofi_regMrDmaBuf_v6, .deregMr = nccl_net_ofi_deregMr_v2, .isend = nccl_net_ofi_isend_v9, 
.irecv = nccl_net_ofi_irecv_v9, .iflush = nccl_net_ofi_iflush_v5, .test = nccl_net_ofi_test_v2, .closeSend = nccl_net_ofi_closeSend_v2, .closeRecv = nccl_net_ofi_closeRecv_v2, .closeListen = nccl_net_ofi_closeListen_v2, .getDeviceMr = NULL, .irecvConsumed = NULL, .makeVDevice = NULL, }; } /* extern "C" */ /* * Versions 1.11.0 and prior of the plugin set the name to * "AWS Libfabric", requiring NCCL_NET be set to "AWS Libfabric", * opening the door to shell escape failures. Customers do have * NCCL_NET="AWS Libfabric" in their various scripts, so still support * that. And, since we're here, also deal with the constant * "Libfabric" vs. "OFI" confusion. */ __attribute__((constructor)) static void nvidia_plugin_name_fixup(void) { char *net_env = getenv("NCCL_NET"); if (net_env != NULL && 0 == strcasecmp(net_env, "AWS Libfabric")) { ncclNetPlugin_v2.name = "AWS Libfabric"; ncclNetPlugin_v3.name = "AWS Libfabric"; ncclNetPlugin_v4.name = "AWS Libfabric"; ncclNetPlugin_v5.name = "AWS Libfabric"; ncclNetPlugin_v6.name = "AWS Libfabric"; ncclNetPlugin_v7.name = "AWS Libfabric"; ncclNetPlugin_v8.name = "AWS Libfabric"; ncclNetPlugin_v9.name = "AWS Libfabric"; } else if (net_env != NULL && 0 == strcasecmp(net_env, "OFI")) { ncclNetPlugin_v2.name = "OFI"; ncclNetPlugin_v3.name = "OFI"; ncclNetPlugin_v4.name = "OFI"; ncclNetPlugin_v5.name = "OFI"; ncclNetPlugin_v6.name = "OFI"; ncclNetPlugin_v7.name = "OFI"; ncclNetPlugin_v8.name = "OFI"; ncclNetPlugin_v9.name = "OFI"; } }