def.bzl (193 lines of code) (raw):
load(
"@local_config_cuda//cuda:build_defs.bzl",
"cuda_default_copts",
_if_cuda = "if_cuda",
)
load(
"@local_config_rocm//rocm:build_defs.bzl",
"rocm_default_copts",
_if_rocm = "if_rocm",
)
# Re-export the loaded helpers under their public names so that other files
# can load `if_rocm` / `if_cuda` from this .bzl file.
if_rocm = _if_rocm
if_cuda = _if_cuda
def _rpm_copy_shared_lib_cmd(lib_path, lib_name):
    """Returns the shell fragment that copies one shared library to the output dir and sets its soname."""
    return ("&& cp -L " + lib_path + "/" + lib_name + " ../$(@D) && patchelf --set-soname " +
            lib_name + " ../$(@D)/" + lib_name)

def rpm_library(
        name,
        hdrs,
        include_path=None,
        lib_path=None,
        rpms=None,
        static_lib=None,
        static_libs=[],  # multi static libs, do not add to cc_library, provide .a filegroup
        shared_lib=None,
        shared_libs=[],
        bins=[],
        include_prefix=None,
        static_link=False,
        deps=[],
        header_only=False,
        tags={},
        **kwargs):
    """Unpacks RPM packages and exposes their headers/libraries as a cc_library.

    A genrule named `<name>_files` extracts every RPM in `rpms` with
    `rpm2cpio | cpio` into a scratch directory called `name`, then copies the
    requested headers, libraries and binaries into the genrule output directory.

    Targets created:
      <name>_files    genrule producing all declared outputs.
      <name>_hdrs_fg  filegroup with the headers.
      <name>_static   filegroup with `static_lib` (only when `static_lib` is set).
      <name>_shared   filegroup with `shared_libs` + `shared_lib` (only when any is set).
      <name>_bins     filegroup with `bins` (only when `bins` is non-empty).
      <name>_import   cc_import of `static_lib`/`shared_lib` (only when
                      `static_lib` is not None).
      <name>          the cc_library that callers depend on.

    Args:
      name: Base name for all generated targets and for the scratch directory.
      hdrs: Header paths; each is prefixed with "include/" to form an output.
      include_path: Directory inside the extracted tree whose contents are
        copied to the output `include` directory (or to the output root when
        `header_only` is True). Nothing is copied when None.
      lib_path: Directory inside the extracted tree that contains the
        libraries. Required whenever any library argument is given.
      rpms: Labels of the RPM files; defaults to ["@<name>//file:file"].
      static_lib: Name of the static archive output. Without `static_libs`,
        every `*.a` under `lib_path` is copied to the output directory.
      static_libs: Archives under `lib_path` that are each unpacked into their
        own directory (to avoid .o name clashes) and merged into `static_lib`.
      shared_lib: Shared library copied from `lib_path`; its soname is reset
        to the file name with patchelf.
      shared_libs: Additional shared libraries handled like `shared_lib`.
      bins: Paths inside the extracted tree copied verbatim to the output dir.
      include_prefix: Forwarded to cc_library.
      static_link: Not used by this macro.
      deps: Extra deps of the resulting cc_library.
      header_only: Selects the header copy destination (see `include_path`).
      tags: Forwarded to the genrule and to the `<name>_bins` filegroup.
      **kwargs: Forwarded to the final cc_library.
    """

    # Fail early with a readable message instead of a "string + NoneType"
    # error from the command construction below.
    if lib_path == None and (static_lib or static_libs or shared_lib or shared_libs):
        fail("rpm_library(%s): lib_path must be set when static/shared libraries are requested" % name)
    if static_libs and not static_lib:
        fail("rpm_library(%s): static_lib must name the merged archive when static_libs is used" % name)

    hdrs = ["include/" + hdr for hdr in hdrs]
    outs = [] + hdrs
    if static_lib:
        outs.append(static_lib)
    if shared_lib:
        outs.append(shared_lib)

    if not rpms:
        rpms = ["@" + name + "//file:file"]

    # The command works inside ./<name>, so the genrule output directory
    # $(@D) is always addressed as ../$(@D).
    cmd_parts = [
        "mkdir " + name + " && cd " + name,
        " && for e in $(SRCS); do rpm2cpio ../$$e | cpio -idm; done",
    ]

    if include_path != None:
        header_dest = "../$(@D)/" if header_only else "../$(@D)/include"
        cmd_parts.append("&& cp -rf " + include_path + "/* " + header_dest)

    if static_libs:
        # Extract each .a into its own directory in case of .o file name
        # conflicts, then ar all objects together into the target archive.
        # The "/" between lib_path and the archive name was missing before,
        # which produced a path like "../usr/liblibfoo.a".
        cmd_parts.append(
            "&& for a in " + " ".join(static_libs) +
            "; do d=$${a%.a} && mkdir $$d && cd $$d && ar x ../" + lib_path +
            "/$$a && cd -; done && ar rc ../$(@D)/" + static_lib + " */*.o",
        )
    elif static_lib:
        cmd_parts.append("&& cp -L " + lib_path + "/*.a" + " ../$(@D)/")

    if shared_lib:
        cmd_parts.append(_rpm_copy_shared_lib_cmd(lib_path, shared_lib))
    for share_lib in shared_libs:
        outs.append(share_lib)
        cmd_parts.append(_rpm_copy_shared_lib_cmd(lib_path, share_lib))

    for path in bins:
        outs.append(path)
        cmd_parts.append("&& cp -rL " + path + " ../$(@D)")

    cmd_parts.append(" && cd -")

    native.genrule(
        name = name + "_files",
        srcs = rpms,
        outs = outs,
        cmd = "".join(cmd_parts),
        visibility = ["//visibility:public"],
        tags = tags,
    )

    hdrs_fg_target = name + "_hdrs_fg"
    native.filegroup(
        name = hdrs_fg_target,
        srcs = hdrs,
    )

    if static_lib:
        native.filegroup(
            name = name + "_static",
            srcs = [static_lib],
            visibility = ["//visibility:public"],
        )

    # `srcs` feeds the cc_library in the cc_import branch below; only the
    # multi-library filegroup is added there because `shared_lib` itself is
    # already passed to cc_import.
    srcs = []
    shared_files = shared_libs + ([shared_lib] if shared_lib else [])
    if shared_files:
        shared_filegroup = name + "_shared"
        native.filegroup(
            name = shared_filegroup,
            srcs = shared_files,
            visibility = ["//visibility:public"],
        )
        if shared_libs:
            srcs.append(shared_filegroup)

    if bins:
        native.filegroup(
            name = name + "_bins",
            srcs = bins,
            visibility = ["//visibility:public"],
            tags = tags,
        )

    if static_lib == None:
        # No static archive: the shared libraries are linked directly as srcs.
        native.cc_library(
            name = name,
            hdrs = [hdrs_fg_target],
            srcs = shared_files,
            deps = deps,
            strip_include_prefix = "include",
            include_prefix = include_prefix,
            visibility = ["//visibility:public"],
            **kwargs
        )
    else:
        import_target = name + "_import"
        native.cc_import(
            name = import_target,
            static_library = static_lib,
            shared_library = shared_lib,
            # static_lib is known to be set in this branch.
            alwayslink = True,
            visibility = ["//visibility:public"],
        )
        native.cc_library(
            name = name,
            hdrs = [hdrs_fg_target],
            srcs = srcs,
            deps = deps + [import_target],
            visibility = ["//visibility:public"],
            strip_include_prefix = "include",
            include_prefix = include_prefix,
            **kwargs
        )
def copts():
    """Returns the compiler options shared by the CUDA and ROCm variants.

    "-DTORCH_CUDA" is always included; the remaining flags are wrapped in
    `if_cuda` / `if_rocm` respectively.
    """
    always = ["-DTORCH_CUDA"]
    cuda_only = if_cuda([
        "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        "-DUSE_C10D_NCCL",
        "-DC10_CUDA_NO_CMAKE_CONFIGURE_FILE",
    ])
    rocm_only = if_rocm(["-x", "rocm"])
    return always + cuda_only + rocm_only
def cuda_copts():
    """Returns `copts()` plus the default CUDA options and the nvcc tempdir flag."""

    # objdir-as-tempdir makes nvcc remove its temporary files after the build.
    nvcc_tempdir_flag = if_cuda(["-nvcc_options=objdir-as-tempdir"])
    shared = copts()
    return shared + cuda_default_copts() + nvcc_tempdir_flag
def rocm_copts():
    """Returns `copts()` plus the default ROCm options and the C++17 extensions warning flag."""
    shared = copts()
    rocm_extra = if_rocm(["-Wc++17-extensions"])
    return shared + rocm_default_copts() + rocm_extra
def any_cuda_copts():
    """Returns the CUDA options followed by the ROCm-specific options.

    The result is `cuda_copts()` (common flags, default CUDA options and the
    nvcc tempdir flag) with the default ROCm options and the ROCm warning
    flag appended, in that order.
    """
    rocm_part = rocm_default_copts() + if_rocm(["-Wc++17-extensions"])
    return cuda_copts() + rocm_part
def gen_cpp_code(
        name,
        elements_list,
        template_header,
        template,
        template_tail,
        element_per_file = 1,
        suffix = ".cpp"):
    """Generates source files by instantiating `template` for every combination.

    One entry is picked from each list in `elements_list`; every combination
    (with the first list varying slowest) is passed positionally to
    `template.format(...)`. An entry that is a tuple contributes all of its
    members as separate format arguments. Instantiations are grouped
    `element_per_file` at a time; each group is written by a genrule named
    `<name>_<k>` to the file `<name>_<k><suffix>` as
    `template_header + instantiations + template_tail`. A filegroup called
    `name` collects all generated genrules.

    Args:
      name: Name of the resulting filegroup and prefix of the genrules.
      elements_list: List of option lists to combine.
      template_header: Text placed at the start of every generated file.
      template: Format string instantiated once per combination.
      template_tail: Text placed at the end of every generated file.
      element_per_file: Number of instantiations per generated file; the last
        file may contain fewer.
      suffix: File extension of the generated files.
    """

    # Total number of combinations: the product of all option-list lengths.
    total = 1
    for options in elements_list:
        total = len(options) * total

    # strides[j] is the number of combinations spanned by one step in list j,
    # i.e. the product of the lengths of all later lists.
    strides = []
    remaining = total
    for options in elements_list:
        remaining = remaining // len(options)
        strides.append(remaining)

    files = []
    pending = 0
    file_index = 0
    body = template_header
    for combo_index in range(total):
        # Decode combo_index into one choice per list.
        format_args = []
        rest = combo_index
        for options, stride in zip(elements_list, strides):
            choice = options[rest // stride]
            if type(choice) == "tuple":
                format_args.extend(choice)
            else:
                format_args.append(choice)
            rest %= stride

        # A leading nested tuple replaces the whole argument list.
        if type(format_args[0]) == "tuple":
            format_args = format_args[0]
        else:
            format_args = tuple(format_args)

        body += template.format(*format_args)
        pending += 1

        # Flush when the file is full or when this is the last combination.
        if pending == element_per_file or combo_index == total - 1:
            cpp_name = name + "_" + str(file_index)
            file_index += 1
            content = body + template_tail

            # The heredoc terminator must sit on a line of its own; without a
            # trailing newline "EOF" would be glued to the last content line,
            # leaving the heredoc unterminated and "EOF" inside the file.
            if content and not content.endswith("\n"):
                content += "\n"
            native.genrule(
                name = cpp_name,
                srcs = [],
                outs = [cpp_name + suffix],
                cmd = "cat > $@ << 'EOF'\n" + content + "EOF",
            )
            pending = 0
            body = template_header
            files.append(cpp_name)

    native.filegroup(
        name = name,
        srcs = files,
    )