python/pyfury/codegen.py (124 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import atexit
import linecache
import os
import uuid
from typing import List, Callable, Union
from pyfury.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
from pyfury.error import CompileError
_type_mapping = {
bool: ("write_bool", "read_bool", "write_nullable_pybool", "read_nullable_pybool"),
int: (
"write_varint64",
"read_varint64",
"write_nullable_pyint64",
"read_nullable_pyint64",
),
float: (
"write_double",
"read_double",
"write_nullable_pyfloat64",
"read_nullable_pyfloat64",
),
str: ("write_string", "read_string", "write_nullable_pystr", "read_nullable_pystr"),
}
def gen_write_nullable_basic_stmts(
    buffer: str,
    value: str,
    type_: type,
) -> List[str]:
    """Generate source lines that write a possibly-None ``value`` to ``buffer``.

    Args:
        buffer: Source expression naming the buffer to write into.
        value: Source expression producing the value to write.
        type_: Python type whose ``_type_mapping`` entry selects the writers.

    Returns:
        Source lines (not yet indented for a function body). With Cython
        serialization enabled this is a single call to the nullable helper;
        otherwise an explicit null-flag ``if``/``else`` followed by the plain
        buffer write.

    Raises:
        KeyError: If ``type_`` has no entry in ``_type_mapping``.
    """
    method_names = _type_mapping[type_]
    from pyfury import ENABLE_FURY_CYTHON_SERIALIZATION

    if not ENABLE_FURY_CYTHON_SERIALIZATION:
        plain_writer = method_names[0]
        return [
            f"if {value} is None:",
            f" {buffer}.write_int8({NULL_FLAG})",
            "else: ",
            f" {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
            f" {buffer}.{plain_writer}({value})",
        ]
    nullable_writer = method_names[2]
    return [f"{nullable_writer}({buffer}, {value})"]
def gen_read_nullable_basic_stmts(
    buffer: str,
    type_: type,
    set_action: Callable[[str], str],
) -> List[str]:
    """Generate source lines that read a possibly-None value of ``type_``.

    Args:
        buffer: Source expression naming the buffer to read from.
        type_: Python type whose ``_type_mapping`` entry selects the readers.
        set_action: Maps a source expression for the read value to the
            statement line to emit for it.

    Returns:
        Source lines (not yet indented for a function body).

    Raises:
        KeyError: If ``type_`` has no entry in ``_type_mapping``.
    """
    method_names = _type_mapping[type_]
    from pyfury import ENABLE_FURY_CYTHON_SERIALIZATION

    if not ENABLE_FURY_CYTHON_SERIALIZATION:
        plain_read = f"{buffer}.{method_names[1]}()"
        # set_action is invoked for the null branch first, then the value
        # branch, matching the order the lines appear in the output.
        on_null = set_action("None")
        on_value = set_action(plain_read)
        return [
            f"if {buffer}.read_int8() == {NULL_FLAG}:",
            f" {on_null}",
            "else: ",
            f" {on_value}",
        ]
    # Cython path: only the nullable helper call is emitted (it presumably
    # handles the null flag itself -- confirm in _serialization).
    nullable_read = f"{method_names[3]}({buffer})"
    return [set_action(nullable_read)]
def compile_function(
    function_name: str,
    params: List[str],
    stmts: List[str],
    context: dict,
):
    """Compile generated statements into a Python function.

    The statements are indented into the body of
    ``def function_name(params):``, every line gets a ``# line N`` suffix,
    and the source is compiled and executed with ``context`` as both globals
    and locals.

    Args:
        function_name: Name of the generated function.
        params: Parameter names for the generated signature.
        stmts: Body statements, not yet indented.
        context: Namespace the code is executed in. It is mutated: the new
            function is bound under ``function_name`` and, when Cython
            serialization is enabled, the Cython nullable read/write helpers
            are added.

    Returns:
        ``(code, function)``: the generated source and the compiled function.

    Raises:
        CompileError: If the generated source fails to compile.
    """
    from pyfury import ENABLE_FURY_CYTHON_SERIALIZATION

    if ENABLE_FURY_CYTHON_SERIALIZATION:
        from pyfury import _serialization

        # Generated code calls these helpers by bare name, so they must be
        # reachable from the exec namespace.
        for helper_name in (
            "write_nullable_pybool",
            "read_nullable_pybool",
            "write_nullable_pyint64",
            "read_nullable_pyint64",
            "write_nullable_pyfloat64",
            "read_nullable_pyfloat64",
            "write_nullable_pystr",
            "read_nullable_pystr",
        ):
            context[helper_name] = getattr(_serialization, helper_name)

    lines = [ident(statement) for statement in stmts]
    lines.insert(0, f"def {function_name}({', '.join(params)}):")
    # Number each line so errors can be matched to the generated source.
    lines = [f"{line} # line {idx + 1}" for idx, line in enumerate(lines)]
    code = "\n".join(lines)

    filename = _generate_filename(function_name)
    code_dir = _get_code_dir()
    if code_dir:
        # Optionally persist the source for inspection; compilation below
        # always uses the in-memory string.
        filename = os.path.join(code_dir, filename)
        # Explicit UTF-8: the locale default may not encode non-ASCII
        # identifiers that can appear in generated code.
        with open(filename, "w", encoding="utf-8") as f:
            f.write(code)
        if _delete_code_on_exit():
            atexit.register(_remove_generated_file, filename)

    try:
        compiled = compile(code, filename, "exec")
    except Exception as e:
        raise CompileError(f"Failed to compile code:\n{code}") from e
    exec(compiled, context, context)

    # See https://stackoverflow.com/questions/64879414/how-does-attrs-fool-the-debugger-to-step-into-auto-generated-code # noqa: E501
    # Register a fake linecache entry so debuggers such as pdb can show and
    # step through the generated source. A None mtime makes
    # linecache.checkcache skip (and thus keep) this entry.
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(True),
        filename,
    )
    return code, context[function_name]


def _remove_generated_file(path):
    """Best-effort exit-time cleanup: delete ``path`` unless it is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Someone else already cleaned it up; nothing left to do.
        pass
# Based on https://github.com/python-attrs/attrs/blob/32fb12789e5cba4b2e71c09e47196b10763ddd7d/src/attr/_make.py#L1863 # noqa: E501
def _generate_filename(func_name):
"""
Create a "filename" suitable for a function being generated.
"""
unique_id = uuid.uuid4()
extra = "0"
count = 1
while True:
filename = f"fury_generated_{func_name}_{extra}.py"
# To handle concurrency we essentially "reserve" our spot in
# the linecache with a dummy line. The caller can then
# set this value correctly.
cache_line = (1, None, [str(unique_id)], filename)
if linecache.cache.setdefault(filename, cache_line) == cache_line:
return filename
# Looks like this spot is taken. Try again.
count += 1
extra = "{0}".format(count)
def _get_code_dir():
code_dir = os.environ.get("FURY_CODE_DIR")
if code_dir is not None and not os.path.exists(code_dir):
os.makedirs(code_dir)
return code_dir
def _delete_code_on_exit():
return os.environ.get("DELETE_CODE_ON_EXIT", "True").lower() in ("true", "1")
def ident_lines(lines: Union[List[str], str]):
    """Indent every line of ``lines`` by one level using ``ident``.

    Args:
        lines: A list (or other iterable) of lines, or a single string that is
            split on newline characters. Subclasses of ``str`` are treated as
            a single string, not as an iterable of characters.

    Returns:
        A list of indented lines for list input; for string input, the
        indented lines joined back together with newlines.
    """
    if isinstance(lines, str):
        return "\n".join([ident(line) for line in lines.split("\n")])
    return [ident(line) for line in lines]
def ident(line: str):
    """Return ``line`` prefixed with one indentation level (four spaces).

    ``line`` must be exactly a ``str`` (checked with ``assert``).
    """
    assert type(line) is str, type(line)
    indent_prefix = " " * 4
    return indent_prefix + line