setup.py (130 lines of code) (raw):
#!/usr/bin/env python3
# Welcome to the PyTorch Captum setup.py.
#
# Environment variables for feature toggles:
#
# BUILD_INSIGHTS
# enables Captum Insights build via yarn
#
import os
import re
import subprocess
import sys
from setuptools import find_packages, setup
REQUIRED_MAJOR = 3
REQUIRED_MINOR = 6
# Check for python version
if sys.version_info < (REQUIRED_MAJOR, REQUIRED_MINOR):
error = (
"Your version of python ({major}.{minor}) is too old. You need "
"python >= {required_major}.{required_minor}."
).format(
major=sys.version_info.major,
minor=sys.version_info.minor,
required_minor=REQUIRED_MINOR,
required_major=REQUIRED_MAJOR,
)
sys.exit(error)
# Allow for environment variable checks
def check_env_flag(name, default=""):
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
BUILD_INSIGHTS = check_env_flag("BUILD_INSIGHTS")
VERBOSE_SCRIPT = True
for arg in sys.argv:
if arg == "-q" or arg == "--quiet":
VERBOSE_SCRIPT = False
def report(*args):
if VERBOSE_SCRIPT:
print(*args)
else:
pass
INSIGHTS_REQUIRES = ["flask", "ipython", "ipywidgets", "jupyter", "flask-compress"]
INSIGHTS_FILE_SUBDIRS = [
"insights/attr_vis/frontend/build",
"insights/attr_vis/models",
"insights/attr_vis/widget/static",
]
TUTORIALS_REQUIRES = INSIGHTS_REQUIRES + ["torchtext", "torchvision"]
TEST_REQUIRES = ["pytest", "pytest-cov", "parameterized"]
DEV_REQUIRES = (
INSIGHTS_REQUIRES
+ TEST_REQUIRES
+ [
"black==21.4b2",
"flake8",
"sphinx",
"sphinx-autodoc-typehints",
"sphinxcontrib-katex",
"mypy>=0.760",
"usort==0.6.4",
"ufmt",
"scikit-learn",
"annoy",
]
)
# get version string from module
with open(os.path.join(os.path.dirname(__file__), "captum/__init__.py"), "r") as f:
version = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M).group(1)
report("-- Building version " + version)
# read in README.md as the long description
with open("README.md", "r") as fh:
long_description = fh.read()
# optionally build Captum Insights via yarn
def build_insights():
report("-- Building Captum Insights")
command = "./scripts/build_insights.sh"
report("Running: " + command)
subprocess.check_call(command)
# explore paths under root and subdirs to gather package files
def get_package_files(root, subdirs):
paths = []
for subroot in subdirs:
paths.append(os.path.join(subroot, "*"))
for path, dirs, _ in os.walk(os.path.join(root, subroot)):
for d in dirs:
paths.append(os.path.join(path, d, "*")[len(root) + 1 :])
return paths
if __name__ == "__main__":
if BUILD_INSIGHTS:
build_insights()
package_files = get_package_files("captum", INSIGHTS_FILE_SUBDIRS)
setup(
name="captum",
version=version,
description="Model interpretability for PyTorch",
author="PyTorch Team",
license="BSD-3",
url="https://captum.ai",
project_urls={
"Documentation": "https://captum.ai",
"Source": "https://github.com/pytorch/captum",
"conda": "https://anaconda.org/pytorch/captum",
},
keywords=[
"Model Interpretability",
"Model Understanding",
"Feature Importance",
"Neuron Importance",
"PyTorch",
],
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering",
],
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=">=3.6",
install_requires=["matplotlib", "numpy", "torch>=1.6"],
packages=find_packages(exclude=("tests", "tests.*")),
extras_require={
"dev": DEV_REQUIRES,
"insights": INSIGHTS_REQUIRES,
"test": TEST_REQUIRES,
"tutorials": TUTORIALS_REQUIRES,
},
package_data={"captum": package_files},
data_files=[
(
"share/jupyter/nbextensions/jupyter-captum-insights",
[
"captum/insights/attr_vis/widget/static/extension.js",
"captum/insights/attr_vis/widget/static/index.js",
],
),
(
"etc/jupyter/nbconfig/notebook.d",
["captum/insights/attr_vis/widget/jupyter-captum-insights.json"],
),
],
)