optimum/graphcore/custom_ops/utils.py (44 lines of code) (raw):
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Source adapted from: https://github.com/graphcore/examples-utils/blob/v3.4/examples_utils/load_lib_utils/load_lib_utils.py
import ctypes
import hashlib
import os
from unittest.mock import patch
import cppimport
from cppimport.importer import (
build_safely,
is_build_needed,
setup_module_data,
)
from .sdk_version_hash import sdk_version_hash
__all__ = ["load_lib"]
settings = {"file_exts": (".cpp",)}
def _calc_cur_checksum_with_sdk_version():
version = sdk_version_hash()
def func(file_lst, module_data):
text = b""
for filepath in file_lst:
with open(filepath, "rb") as f:
text += f.read()
cpphash = hashlib.md5(text).hexdigest()
hash = f"SDK-VERSION-{version}-{cpphash}"
return hash
return func
def _build(filepath, timeout: int = 5 * 60):
if not os.path.exists(filepath):
raise FileNotFoundError(f"File does not exist: {filepath}")
filepath = os.path.abspath(filepath)
old_timeout = cppimport.settings.get("lock_timeout", 5 * 60)
try:
cppimport.settings["lock_timeout"] = timeout
# TODO: remove hack once ticket resolved: https://github.com/tbenthompson/cppimport/issues/76
# TODO: A hack to include the SDK version hash as part of the cppimport hash
with patch("cppimport.checksum._calc_cur_checksum", new=_calc_cur_checksum_with_sdk_version()):
fullname = os.path.splitext(os.path.basename(filepath))[0]
module_data = setup_module_data(fullname, filepath)
if is_build_needed(module_data):
build_safely(filepath, module_data)
binary_path = module_data["ext_path"]
finally:
cppimport.settings["lock_timeout"] = old_timeout
return binary_path
def load_lib(filepath: str, timeout: int = 5 * 60):
"""Compile a C++ source file using `cppimport`, load the shared library into the process using `ctypes` and
return it.
Compilation is not triggered if an existing binary matches the source file hash and Graphcore SDK version which is
embedded in the binary file.
`cppimport` is used to compile the source which uses a special comment in the C++ file that includes the
compilation parameters. Here is an example of such a comment which defines compiler flags, additional sources files
and library options (see `cppimport` documentation for more info):
```
/*
<%
cfg['sources'] = ['another_source.cpp']
cfg['extra_compile_args'] = ['-std=c++14', '-fPIC', '-O2', '-DONNX_NAMESPACE=onnx', '-Wall']
cfg['libraries'] = ['popart', 'poplar', 'popops', 'poputil', 'popnn']
setup_pybind11(cfg)
%>
*/
```
Its also recommended to include the cppimport header at the top of the source file `\\ cppimport` to indicate that
it will be loaded via cppimport and so the `load_lib_all` function will build it.
Parameters:
filepath (str): File path of the C++ source file
timeout (int): Timeout time if cannot obtain lock to compile the source
Returns:
lib: library instance. Output of `ctypes.cdll.LoadLibrary`
"""
binary_path = _build(filepath, timeout)
lib = ctypes.cdll.LoadLibrary(binary_path)
return lib