scripts/init-kernel.py (99 lines of code) (raw):
# This script creates the necessary files for a new kernel example in the specified directory.
#
# Example Usage:
# $ uv run scripts/init-kernel.py relu
#
# Created directory: relu
#
# relu/
# ├── relu_kernel/
# │ └── relu.cu
# ├── tests/
# │ ├── __init__.py
# │ └── test_relu.py
# ├── torch-ext/
# │ ├── relu/
# │ │ └── __init__.py
# │ ├── torch_binding.cpp
# │ └── torch_binding.h
# ├── build.toml
# └── flake.nix
#
# ✓ Success! All files for the ReLU example have been created successfully.
#
# Next steps:
# 1. Build the kernel: cd relu && git add . && nix develop -L
# 2. Run the tests: pytest -vv tests/
import os
import argparse
import pathlib
class Colors:
HEADER = "\033[95m"
BLUE = "\033[94m"
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
GREY = "\033[90m"
def create_file_with_content(file_path: str, content: str):
"""Creates a file at 'file_path' with the specified content."""
directory = os.path.dirname(file_path)
if directory and not os.path.exists(directory):
os.makedirs(directory)
with open(file_path, "w") as f:
f.write(content)
# Generate a tree view of the created files
def print_tree(directory: str, prefix: str = ""):
entries = sorted(os.listdir(directory))
# Process directories first, then files
dirs = [e for e in entries if os.path.isdir(os.path.join(directory, e))]
files = [e for e in entries if os.path.isfile(os.path.join(directory, e))]
# Process all items except the last one
count = len(dirs) + len(files)
# Print directories
for i, dirname in enumerate(dirs):
is_last_dir = i == len(dirs) - 1 and len(files) == 0
connector = "└── " if is_last_dir else "├── "
print(
f" {prefix}{connector}{Colors.BOLD}{Colors.BLUE}{dirname}/{Colors.ENDC}"
)
# Prepare the prefix for the next level
next_prefix = prefix + (" " if is_last_dir else "│ ")
print_tree(os.path.join(directory, dirname), next_prefix)
# Print files
for i, filename in enumerate(files):
is_last = i == len(files) - 1
connector = "└── " if is_last else "├── "
file_color = ""
print(f" {prefix}{connector}{file_color}{filename}{Colors.ENDC}")
def main():
# Get the directory where this script is located
script_dir = pathlib.Path(__file__).parent.resolve().parent.resolve()
# Create argument parser
parser = argparse.ArgumentParser(
description="Create ReLU example files in the specified directory"
)
parser.add_argument(
"target_dir", help="Target directory where files will be created"
)
args = parser.parse_args()
# Get the target directory from arguments
target_dir = args.target_dir
# Create the target directory if it doesn't exist
if not os.path.exists(target_dir):
os.makedirs(target_dir)
print(
f"\n{Colors.CYAN}{Colors.BOLD}Created directory: {Colors.BOLD}{target_dir}{Colors.ENDC}\n"
)
else:
print(
f"\n{Colors.CYAN}{Colors.BOLD}Directory already exists: {Colors.BOLD}{target_dir}{Colors.ENDC}\n"
)
# get files from examples/relu
relu_dir = script_dir / "examples" / "relu"
for root, _, files in os.walk(relu_dir):
for file in files:
file_path = os.path.join(root, file)
with open(file_path, "r") as f:
content = f.read()
# Replace kernel-builder.url with path:../ in flake.nix
if file_path.endswith("flake.nix"):
kernel_builder_url_start = content.find("kernel-builder.url =")
kernel_builder_url_end = content.find(";", kernel_builder_url_start)
content = (
content[:kernel_builder_url_start]
+ 'kernel-builder.url = "path:../"'
+ content[kernel_builder_url_end:]
)
target_file = file_path.replace(str(relu_dir), target_dir)
create_file_with_content(target_file, content)
print(f" {Colors.BOLD}{target_dir}/{Colors.ENDC}")
print_tree(target_dir)
print(
f"\n{Colors.GREEN}{Colors.BOLD}✓ Success!{Colors.ENDC} All files for the ReLU example have been created successfully."
)
print(f"\n{Colors.CYAN}{Colors.BOLD}Next steps:{Colors.ENDC}")
commands = [
"nix run nixpkgs#cachix -- use huggingface",
f"cd {target_dir}",
"git add .",
"nix develop -L",
]
for index, command in enumerate(commands, start=1):
print(
f" {Colors.YELLOW}{index}.{Colors.ENDC} {Colors.BOLD}{command}{Colors.ENDC}"
)
print(
f"\none line build:\n{Colors.GREY}{Colors.BOLD}{' && '.join(commands)}{Colors.ENDC}{Colors.ENDC}"
)
print(f"\n{Colors.CYAN}{Colors.BOLD}Run the tests{Colors.ENDC}")
print(
f" {Colors.YELLOW}{1}.{Colors.ENDC} {Colors.BOLD}pytest -vv tests/{Colors.ENDC}"
)
print("")
if __name__ == "__main__":
main()