include/tracing_impl/nvtx.h (200 lines of code) (raw):

/*
 * Copyright (c) 2022-2024 Amazon.com, Inc. or its affiliates. All rights reserved.
 */

#ifndef NVTX_H
#define NVTX_H

/*
 * NVTX instrumentation hooks for the RDMA protocol.
 *
 * When HAVE_NVTX_TRACING is non-zero, each NCCL_OFI_TRACE_*_NVTX macro emits
 * an NVTX mark, or starts/ends an NVTX range, in a per-communicator domain
 * (when NCCL_OFI_NVTX_TRACE_PER_COMM is true) and/or a per-rail device
 * domain (when NCCL_OFI_NVTX_TRACE_PER_DEV is true). Otherwise every macro
 * expands to an empty statement and its arguments are not evaluated.
 *
 * This header does not declare the RDMA types and helpers the enabled macros
 * expand to (nccl_net_ofi_rdma_send_comm_t, nccl_net_ofi_rdma_recv_comm_t,
 * nccl_net_ofi_rdma_ep_t, rdma_endpoint_get_device(), get_send_data(),
 * get_recv_data(), get_send_ctrl_data(), NCCL_OFI_N_NVTX_DOMAIN_PER_COMM and
 * the NCCL_OFI_NVTX_TRACE_* switches); they must be in scope wherever the
 * macros are expanded.
 */

#if HAVE_NVTX_TRACING

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include <nvtx3/nvToolsExt.h>

/*
 * Build event attributes carrying the ASCII message NAME and the ARGB color
 * COLOR. "{0}" rather than "{}" is used to zero the remaining fields because
 * an empty initializer is only valid C from C23 on.
 */
static inline nvtxEventAttributes_t nvtx_make_event_attrib(const char *name, uint32_t color)
{
	nvtxEventAttributes_t attrib = {0};

	attrib.version = NVTX_VERSION;
	attrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
	attrib.colorType = NVTX_COLOR_ARGB;
	attrib.color = color;
	attrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
	attrib.message.ascii = name;
	return attrib;
}

/*
 * Emit an instantaneous marker named NAME with color COLOR in DOMAIN.
 * The generic markers at the bottom of this file pass a NULL domain,
 * presumably selecting NVTX's default domain (NOTE(review): confirm).
 */
static inline void nvtx_mark_domain(nvtxDomainHandle_t domain, const char *name, uint32_t color)
{
	nvtxEventAttributes_t attrib = nvtx_make_event_attrib(name, color);

	nvtxDomainMarkEx(domain, &attrib);
}

/*
 * Start a range named NAME with color COLOR.
 *
 * If HAVE_DOMAIN is true the range is started in DOMAIN with
 * nvtxDomainRangeStartEx(); pair it with nvtx_end_domain(DOMAIN, id).
 * Otherwise DOMAIN is ignored and the range is started with
 * nvtxRangeStartEx(); pair it with nvtx_end(id).
 *
 * Returns the NVTX range id to pass to the matching end call.
 */
static inline nvtxRangeId_t nvtx_start_domain(bool have_domain, nvtxDomainHandle_t domain,
					      const char *name, uint32_t color)
{
	nvtxEventAttributes_t attrib = nvtx_make_event_attrib(name, color);

	if (have_domain)
		return nvtxDomainRangeStartEx(domain, &attrib);
	return nvtxRangeStartEx(&attrib);
}

/* Start a range without an explicit domain; pair it with nvtx_end(). */
static inline nvtxRangeId_t nvtx_start(const char *name, uint32_t color)
{
	return nvtx_start_domain(false, NULL, name, color);
}

/* End range ID that was started in DOMAIN via nvtx_start_domain(true, ...). */
static inline void nvtx_end_domain(nvtxDomainHandle_t domain, nvtxRangeId_t id)
{
	nvtxDomainRangeEnd(domain, id);
}

/* End range ID that was started via nvtx_start(). */
static inline void nvtx_end(nvtxRangeId_t id)
{
	nvtxRangeEnd(id);
}

/*
 * Per-communicator domain for message MSG_SEQ_NUM: COMM is cast to COMM_TYPE
 * and its nvtx_domain[] array is indexed by MSG_SEQ_NUM modulo
 * NCCL_OFI_N_NVTX_DOMAIN_PER_COMM.
 */
#define NCCL_OFI_NVTX_COMM_DOMAIN(comm_type, comm, msg_seq_num) \
	(((comm_type *)(comm))->nvtx_domain[(msg_seq_num) % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM])

/* Per-rail domain RAIL_ID of the device returned by rdma_endpoint_get_device(EP). */
#define NCCL_OFI_NVTX_DEV_DOMAIN(ep, rail_id) \
	(rdma_endpoint_get_device((nccl_net_ofi_rdma_ep_t *)(ep))->nvtx_domain[(rail_id)])

/*
 * NOTE(review): in the *_START / *_COMPLETE pairs below, enabling both
 * NCCL_OFI_NVTX_TRACE_PER_COMM and NCCL_OFI_NVTX_TRACE_PER_DEV makes the
 * per-device start overwrite the id stored by the per-communicator start, and
 * the matching end then closes both domains with that single id. Presumably
 * the two modes are mutually exclusive -- confirm before enabling both.
 */

/* Start the "Send" range for REQUEST in its communicator's domain. */
#define NCCL_OFI_TRACE_SEND_NVTX(dev, size, comm, msg_seq_num, request, nccl_req) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		get_send_data(request)->trace_id = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, comm, msg_seq_num), \
			"Send", 0xeb9234); \
	} \
} while (0)

/* End the "Send" range started by NCCL_OFI_TRACE_SEND_NVTX for REQUEST. */
#define NCCL_OFI_TRACE_SEND_END_NVTX(request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_end_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, \
				(request)->comm, (request)->msg_seq_num), \
			get_send_data(request)->trace_id); \
	} \
} while (0)

/* Start the "Send_eager" range on rail RAIL_ID. */
#define NCCL_OFI_TRACE_EAGER_SEND_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, comm, msg_seq_num), \
			"Send_eager", 0x0000FF); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_DEV_DOMAIN((comm)->ep, rail_id), \
			"Send_eager", 0x0000FF); \
	} \
} while (0)

/* End the "Send_eager" range on rail RAIL_ID. */
#define NCCL_OFI_TRACE_EAGER_SEND_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_end_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, comm, msg_seq_num), \
			get_send_data(request)->seg_trace_id[rail_id]); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		nvtx_end_domain(NCCL_OFI_NVTX_DEV_DOMAIN((comm)->ep, rail_id), \
			get_send_data(request)->seg_trace_id[rail_id]); \
	} \
} while (0)

/*
 * Mark "Send_ctrl_recv" on rail RAIL_ID. The device domain is derived from
 * COMM itself rather than from a caller-scope variable.
 */
#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(dev, rail_id, comm, msg_seq_num) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_mark_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, comm, msg_seq_num), \
			"Send_ctrl_recv", 0x00ffff); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		nvtx_mark_domain(NCCL_OFI_NVTX_DEV_DOMAIN( \
				((nccl_net_ofi_rdma_send_comm_t *)(comm))->base.base.ep, rail_id), \
			"Send_ctrl_recv", 0x00ffff); \
	} \
} while (0)

/* Start the "Send_ctrl_start" range for control request REQ on rail RAIL_ID. */
#define NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(dev, rail_id, comm, req, msg_seq_num) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		get_send_ctrl_data(req)->trace_id = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_recv_comm_t, comm, msg_seq_num), \
			"Send_ctrl_start", 0x00ffff); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		get_send_ctrl_data(req)->trace_id = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_DEV_DOMAIN((comm)->ep, rail_id), \
			"Send_ctrl_start", 0x00ffff); \
	} \
} while (0)

/* End the "Send_ctrl_start" range for control request REQ on rail RAIL_ID. */
#define NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(dev, rail_id, comm, req, msg_seq_num) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_end_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_recv_comm_t, comm, msg_seq_num), \
			get_send_ctrl_data(req)->trace_id); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		nvtx_end_domain(NCCL_OFI_NVTX_DEV_DOMAIN((comm)->ep, rail_id), \
			get_send_ctrl_data(req)->trace_id); \
	} \
} while (0)

/* Start the "Send_write_seg" range on rail RAIL_ID. */
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, comm, msg_seq_num), \
			"Send_write_seg", 0xff0000); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_DEV_DOMAIN((comm)->ep, rail_id), \
			"Send_write_seg", 0xff0000); \
	} \
} while (0)

/* End the "Send_write_seg" range on rail RAIL_ID. */
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_end_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_send_comm_t, comm, msg_seq_num), \
			get_send_data(request)->seg_trace_id[rail_id]); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		nvtx_end_domain(NCCL_OFI_NVTX_DEV_DOMAIN((comm)->ep, rail_id), \
			get_send_data(request)->seg_trace_id[rail_id]); \
	} \
} while (0)

/*
 * Start the "Recv" range for REQUEST. The domain is chosen from
 * REQUEST->comm and REQUEST->msg_seq_num (the COMM argument is unused),
 * matching NCCL_OFI_TRACE_RECV_END_NVTX so both ends use the same domain.
 */
#define NCCL_OFI_TRACE_RECV_NVTX(dev, comm, size, request, nccl_req) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		get_recv_data(request)->trace_id = nvtx_start_domain(true, \
			NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_recv_comm_t, \
				(request)->comm, (request)->msg_seq_num), \
			"Recv", 0x34EB37); \
	} \
} while (0)

/* End the "Recv" range started by NCCL_OFI_TRACE_RECV_NVTX for REQUEST. */
#define NCCL_OFI_TRACE_RECV_END_NVTX(request) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_end_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_recv_comm_t, \
				(request)->comm, (request)->msg_seq_num), \
			get_recv_data(request)->trace_id); \
	} \
} while (0)

/* Mark "Recv_segment_complete" for REQUEST on rail RAIL_ID. */
#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(dev, rail_id, size, request, msg_seq_num) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_mark_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_recv_comm_t, \
				(request)->comm, msg_seq_num), \
			"Recv_segment_complete", 0xff0000); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		nvtx_mark_domain(NCCL_OFI_NVTX_DEV_DOMAIN((request)->comm->ep, rail_id), \
			"Recv_segment_complete", 0xff0000); \
	} \
} while (0)

/*
 * Mark "Eager_recv" on rail RAIL_ID. The device domain is derived from COMM
 * itself rather than from a caller-scope variable.
 */
#define NCCL_OFI_TRACE_EAGER_RECV_NVTX(dev, rail_id, comm, msg_seq_num) do { \
	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
		nvtx_mark_domain(NCCL_OFI_NVTX_COMM_DOMAIN(nccl_net_ofi_rdma_recv_comm_t, comm, msg_seq_num), \
			"Eager_recv", 0x0000FF); \
	} \
	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
		nvtx_mark_domain(NCCL_OFI_NVTX_DEV_DOMAIN( \
				((nccl_net_ofi_rdma_recv_comm_t *)(comm))->base.base.ep, rail_id), \
			"Eager_recv", 0x0000FF); \
	} \
} while (0)

/* Generic markers emitted with a NULL domain handle; arguments are unused. */
#define NCCL_OFI_TRACE_FLUSH_NVTX(request, nccl_req) do { \
	nvtx_mark_domain(NULL, "Flush", 0xA52A2A); \
} while (0)

#define NCCL_OFI_TRACE_READ_NVTX(request, nccl_req) do { \
	nvtx_mark_domain(NULL, "Read", 0xff00ff); \
} while (0)

#define NCCL_OFI_TRACE_WRITE_NVTX(request, nccl_req) do { \
	nvtx_mark_domain(NULL, "Write", 0xff00ff); \
} while (0)

#define NCCL_OFI_TRACE_PENDING_INSERT_NVTX(request) do { \
	nvtx_mark_domain(NULL, "Pending_insert", 0xFF8C00); \
} while (0)

#define NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(request) do { \
	nvtx_mark_domain(NULL, "Pending_remove", 0xFF8C00); \
} while (0)

#else /* !HAVE_NVTX_TRACING */

/*
 * Tracing disabled: each hook is an empty statement (arguments are not
 * evaluated). The do/while form keeps "if (x) HOOK(...); else ..." valid and
 * avoids empty-body warnings.
 */
#define NCCL_OFI_TRACE_SEND_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_SEND_END_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_EAGER_SEND_START_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_EAGER_SEND_COMPLETE_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_RECV_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_RECV_END_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_EAGER_RECV_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_FLUSH_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_READ_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_WRITE_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_PENDING_INSERT_NVTX(...) do { } while (0)
#define NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(...) do { } while (0)

#endif /* HAVE_NVTX_TRACING */

#endif /* NVTX_H */