# open_source/bazel/arch_select.bzl
# Wrapper targets that resolve differently depending on the system configuration.
load("@pip_cpu_torch//:requirements.bzl", requirement_cpu="requirement")
load("@pip_arm_torch//:requirements.bzl", requirement_arm="requirement")
load("@pip_gpu_torch//:requirements.bzl", requirement_gpu="requirement")
load("@pip_gpu_cuda12_torch//:requirements.bzl", requirement_gpu_cuda12="requirement")
load("@pip_gpu_rocm_torch//:requirements.bzl", requirement_gpu_rocm="requirement")
def requirement(names):
    """Define a public py_library wrapper for each pip package in `names`.

    Each generated target picks the pip repository matching the active
    configuration: CUDA 12, CUDA 11, ROCm, ARM, or the CPU default.
    """
    for pkg in names:
        native.py_library(
            name = pkg,
            visibility = ["//visibility:public"],
            deps = select({
                "@//:using_cuda12": [requirement_gpu_cuda12(pkg)],
                "@//:using_cuda11": [requirement_gpu(pkg)],
                "@//:using_rocm": [requirement_gpu_rocm(pkg)],
                "@//:using_arm": [requirement_arm(pkg)],
                "//conditions:default": [requirement_cpu(pkg)],
            }),
        )
def cache_store_deps():
    """Alias the cache_store implementation used by this build flavor."""
    native.alias(
        name = "cache_store_arch_select_impl",
        actual = "//maga_transformer/cpp/disaggregate/cache_store:cache_store_base_impl",
    )
def embedding_arpc_deps():
    """Alias the embedding-engine ARPC server implementation."""
    native.alias(
        name = "embedding_arpc_deps",
        actual = "//maga_transformer/cpp/embedding_engine:embedding_engine_arpc_server_impl",
    )
def subscribe_deps():
    """Alias the load-balancer subscribe service implementation."""
    native.alias(
        name = "subscribe_deps",
        actual = "//maga_transformer/cpp/disaggregate/load_balancer/subscribe:subscribe_service_impl",
    )
def whl_deps():
    """Return the torch wheel requirement(s) for the active configuration."""
    wheel_by_config = {
        "@//:using_cuda12": ["torch==2.1.2+cu121"],
        "@//:using_cuda11": ["torch==2.1.2+cu118"],
        "@//:using_rocm": ["torch==2.1.2", "pyyaml"],
        "//conditions:default": ["torch==2.1.2"],
    }
    return select(wheel_by_config)
def platform_deps():
    """Return platform-specific python requirements.

    decord is only used on the default (x86 CPU/GPU) platform; ROCm pulls
    in pyyaml and ARM gets nothing extra.
    """
    platform_by_config = {
        "@//:using_arm": [],
        "@//:using_rocm": ["pyyaml"],
        "//conditions:default": ["decord==0.6.0"],
    }
    return select(platform_by_config)
def _torch_targets(repo):
    # Expand a torch repository prefix into its api/lib/libs target triple.
    return [repo + ":torch_api", repo + ":torch", repo + ":torch_libs"]

def torch_deps():
    """Select the prebuilt torch targets matching the active configuration."""
    return select({
        "@//:using_rocm": _torch_targets("@torch_rocm//"),
        "@//:using_arm": _torch_targets("@torch_2.3_py310_cpu_aarch64//"),
        "@//:using_cuda": _torch_targets("@torch_2.6_py310_cuda//"),
        "//conditions:default": _torch_targets("@torch_2.1_py310_cpu//"),
    })
def fa_deps():
    """Aliases for the flash-attention library and its headers."""
    for alias_name, target in [
        ("fa", "@flash_attention//:fa"),
        ("fa_hdrs", "@flash_attention//:fa_hdrs"),
    ]:
        native.alias(name = alias_name, actual = target)
def flashinfer_deps():
    """Alias the external flashinfer library."""
    native.alias(name = "flashinfer", actual = "@flashinfer//:flashinfer")
def flashmla_deps():
    """Alias the external flashmla library."""
    native.alias(name = "flashmla", actual = "@flashmla//:flashmla")
def kernel_so_deps():
    """Return kernel shared-object deps for the active configuration.

    CUDA builds link the full attention/MoE/flashinfer/deepgemm set; ROCm
    only the masked-multi-head-attention libraries; pure CPU builds none.
    """
    cuda_sos = [
        ":libmmha1_so",
        ":libmmha2_so",
        ":libdmmha_so",
        ":libfa_so",
        ":libfpA_intB_so",
        ":libint8_gemm_so",
        ":libmoe_so",
        ":libmoe_sm90_so",
        ":libflashinfer_single_prefill_so",
        ":libflashinfer_single_decode_so",
        ":libflashinfer_batch_paged_prefill_so",
        ":libflashinfer_batch_paged_decode_so",
        ":libflashinfer_batch_ragged_prefill_so",
        ":libflashinfer_sm90_so",
        ":libdeepgemm_dpsk_inst_so",
        ":libdeepgemm_qwen_inst_so",
    ]
    rocm_sos = [":libmmha1_so", ":libmmha2_so", ":libdmmha_so"]
    return select({
        "@//:using_cuda": cuda_sos,
        "@//:using_rocm": rocm_sos,
        "//conditions:default": [],
    })
def arpc_deps():
    # NOTE(review): this declares a cc_library with an EMPTY name, which looks
    # like a stub for the open-source build (possibly replaced by an internal
    # version) — confirm it is intentional before depending on it.
    native.cc_library(
        name = ""
    )
def trt_plugins():
    """Alias the TensorRT plugins library.

    Both select branches currently resolve to the same target — presumably
    a placeholder so a CUDA-12-specific build can diverge later.
    """
    plugins_target = select({
        "@//:using_cuda12": "//maga_transformer/cpp/trt_plugins:trt_plugins",
        "//conditions:default": "//maga_transformer/cpp/trt_plugins:trt_plugins",
    })
    native.alias(name = "trt_plugins", actual = plugins_target)
def cuda_register():
    """Alias the GPU device registration target.

    The select has a single default branch; it is kept so configuration-
    specific registrations can be added without changing callers.
    """
    register_target = select({
        "//conditions:default": "//maga_transformer/cpp/devices/cuda_impl:gpu_register",
    })
    native.alias(name = "cuda_register", actual = register_target)
def triton_deps(names):
    """Return triton dependencies for the given package names.

    `names` is currently unused: every configuration resolves to an empty
    list in this open-source build; the parameter is kept for caller
    compatibility.
    """
    return select({"//conditions:default": []})
def internal_deps():
    """Internal-only dependencies; always empty in the open-source build."""
    return []