include/tuner/nccl_ofi_tuner_common.h (43 lines of code) (raw):
/*
* Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All rights reserved.
*/
#ifndef NCCL_OFI_TUNER_COMMON_H_
#define NCCL_OFI_TUNER_COMMON_H_
#include "config.h"
#include <linux/limits.h>
#include <nccl/tuner.h>
/*
 * Forward declaration of the tuner context type, so that the function
 * pointer members declared inside struct nccl_ofi_tuner_context can take a
 * pointer to the context using its typedef name.
 */
typedef struct nccl_ofi_tuner_context nccl_ofi_tuner_context_t;
/*
 * Tuner type: region-based vs. model-based.
 *
 * Stored in nccl_ofi_tuner_context::type to tell which kind of tuner a
 * context belongs to.
 */
enum nccl_ofi_tuner_type {
	/* Region-based tuner */
	NCCL_OFI_TUNER_TYPE_REGION = 0,
	/* Model-based tuner */
	NCCL_OFI_TUNER_TYPE_MODEL = 1
};
/*
 * Platform the tuner is being set up for.
 *
 * Known platforms are numbered from zero. NCCL_OFI_TUNER_UNKNOWN is left
 * without an explicit value so that it always takes the value right after
 * the last known platform, and NCCL_OFI_TUNER_PLATFORM_MAX is an alias of
 * it (i.e. it equals the number of known platforms).
 */
enum nccl_ofi_tuner_platform {
	NCCL_OFI_TUNER_P5_P5E = 0,
	NCCL_OFI_TUNER_P5EN = 1,

	/* Sentinel: platform not recognized. Keep this after all known platforms. */
	NCCL_OFI_TUNER_UNKNOWN,

	/* Alias of the sentinel; same numeric value as NCCL_OFI_TUNER_UNKNOWN. */
	NCCL_OFI_TUNER_PLATFORM_MAX = NCCL_OFI_TUNER_UNKNOWN
};
/*
 * Tuner context.
 *
 * Records which tuner type ("Region" or "Model") is in use and carries
 * that type's private data together with the function pointers that
 * implement it. Every function pointer takes the context itself as its
 * first argument.
 */
struct nccl_ofi_tuner_context {
	/* Tuner type ("Region" or "Model") this context belongs to */
	enum nccl_ofi_tuner_type type;

	/* Pointer to tuner type ("Region" or "Model") specific context data */
	void *type_ctx;

	/*
	 * Tuner type ("Region" or "Model") specific functions
	 */

	/* Initialization hook; receives the platform and the rank/node counts. */
	ncclResult_t (*init_internal)(nccl_ofi_tuner_context_t *ctx,
		enum nccl_ofi_tuner_platform platform,
		size_t nRanks,
		size_t nNodes);

	/*
	 * Collective-info hook, "v3" variant: receives a cost table pointer
	 * (collCostTable) together with numAlgo and numProto, plus an
	 * nChannels pointer.
	 * NOTE(review): the parameter list presumably mirrors the NCCL tuner
	 * plugin v3 getCollInfo interface -- confirm against nccl/tuner.h.
	 */
	ncclResult_t (*get_coll_info_internal_v3)(nccl_ofi_tuner_context_t *ctx,
		ncclFunc_t collType,
		size_t nBytes,
		int numPipeOps,
		float **collCostTable,
		int numAlgo,
		int numProto,
		int *nChannels);

	/*
	 * Collective-info hook, "v2" variant: instead of a cost table it
	 * receives collNetSupport/nvlsSupport flags and algorithm, protocol
	 * and nChannels pointers.
	 * NOTE(review): the parameter list presumably mirrors the NCCL tuner
	 * plugin v2 getCollInfo interface -- confirm against nccl/tuner.h.
	 */
	ncclResult_t (*get_coll_info_internal_v2)(nccl_ofi_tuner_context_t *ctx,
		ncclFunc_t collType,
		size_t nBytes,
		int collNetSupport,
		int nvlsSupport,
		int numPipeOps,
		int *algorithm,
		int *protocol,
		int *nChannels);

	/*
	 * Teardown hook for the context.
	 * NOTE(review): presumably releases type_ctx -- confirm in the
	 * region/model implementations.
	 */
	ncclResult_t (*destroy_internal)(nccl_ofi_tuner_context_t *ctx);
};
#endif /* NCCL_OFI_TUNER_COMMON_H_ */