tinynn/graph/_utils.py (254 lines of code) (raw):
# The following code is based on forbiddenfruit.
# URL: https://github.com/clarete/forbiddenfruit
#
# Copyright (c) 2013-2020 Lincoln de Sousa <lincoln@clarete.li>
#
# This program is licensed under MIT.
#
# MIT
# ---
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import ctypes
import gc
import types
from functools import wraps
# CPython's Py_ssize_t: a signed integer as wide as size_t. ctypes ships the
# matching type, so no pointer-size check is needed.
Py_ssize_t = ctypes.c_ssize_t
class PyObject(ctypes.Structure):
    """ctypes view of a CPython object header (``PyObject``).

    ``_fields_`` is assigned further down in this module, because it refers
    to ``PyTypeObject``. The helpers below adjust ``ob_refcnt`` by writing the
    field directly; nothing is deallocated if the count reaches zero.
    """

    def incref(self):
        """Add one to ``ob_refcnt``."""
        self.ob_refcnt = self.ob_refcnt + 1

    def decref(self):
        """Subtract one from ``ob_refcnt``."""
        self.ob_refcnt = self.ob_refcnt - 1
class PyFile(ctypes.Structure):
    """Opaque stand-in for C's ``FILE``; only used behind a pointer (``FILE_p``)."""
# ctypes counterparts of the CPython C-API typedefs used by the slot tables
# below. ``PyObject_p`` is ctypes.py_object, so object arguments and object
# results of callbacks built from these types are plain Python objects, not
# raw addresses.
PyObject_p = ctypes.py_object
FILE_p = ctypes.POINTER(PyFile)

# Slots that return an object.
UnaryFunc_p = ctypes.CFUNCTYPE(PyObject_p, PyObject_p)  # unaryfunc
BinaryFunc_p = ctypes.CFUNCTYPE(PyObject_p, PyObject_p, PyObject_p)  # binaryfunc
TernaryFunc_p = ctypes.CFUNCTYPE(PyObject_p, PyObject_p, PyObject_p, PyObject_p)  # ternaryfunc
SSizeArgFunc_p = ctypes.CFUNCTYPE(PyObject_p, PyObject_p, Py_ssize_t)  # ssizeargfunc

# Slots that return a C integer (status flag or length).
Inquiry_p = ctypes.CFUNCTYPE(ctypes.c_int, PyObject_p)  # inquiry
LenFunc_p = ctypes.CFUNCTYPE(Py_ssize_t, PyObject_p)  # lenfunc
SSizeObjArgProc_p = ctypes.CFUNCTYPE(ctypes.c_int, PyObject_p, Py_ssize_t, PyObject_p)  # ssizeobjargproc
ObjObjProc_p = ctypes.CFUNCTYPE(ctypes.c_int, PyObject_p, PyObject_p)  # objobjproc
ObjObjArgProc_p = ctypes.CFUNCTYPE(ctypes.c_int, PyObject_p, PyObject_p, PyObject_p)  # objobjargproc
def get_not_implemented():
    """Return the ``NotImplemented`` singleton.

    The slot wrappers in this module return this value when the patched
    callable raises ``NotImplementedError``. Every slot callback here declares
    a ``ctypes.py_object`` result, so the object itself can be returned. There
    is no need to fetch it through the private ``_Py_NotImplementedStruct``
    symbol via ``ctypes.pythonapi``, which is not part of the public C API.
    """
    return NotImplemented


# The NotImplemented singleton, returned by the slot wrappers below.
NotImplementedRet = get_not_implemented()
class PyNumberMethods(ctypes.Structure):
    """ctypes mirror of CPython's ``PyNumberMethods`` table (``tp_as_number``).

    Field order follows the C struct; every entry is one pointer-sized slot.
    """

    _fields_ = [
        ('nb_add', BinaryFunc_p),
        ('nb_subtract', BinaryFunc_p),
        ('nb_multiply', BinaryFunc_p),
        ('nb_remainder', BinaryFunc_p),
        ('nb_divmod', BinaryFunc_p),
        # NOTE: in C this is a ternaryfunc (the third argument is pow()'s
        # modulus). It stays declared as BinaryFunc_p so existing two-argument
        # callbacks keep working; such callbacks do not receive the modulus.
        ('nb_power', BinaryFunc_p),
        ('nb_negative', UnaryFunc_p),
        ('nb_positive', UnaryFunc_p),
        ('nb_absolute', UnaryFunc_p),
        ('nb_bool', Inquiry_p),
        ('nb_invert', UnaryFunc_p),
        ('nb_lshift', BinaryFunc_p),
        ('nb_rshift', BinaryFunc_p),
        ('nb_and', BinaryFunc_p),
        ('nb_xor', BinaryFunc_p),
        ('nb_or', BinaryFunc_p),
        ('nb_int', UnaryFunc_p),
        ('nb_reserved', ctypes.c_void_p),
        ('nb_float', UnaryFunc_p),
        ('nb_inplace_add', BinaryFunc_p),
        ('nb_inplace_subtract', BinaryFunc_p),
        ('nb_inplace_multiply', BinaryFunc_p),
        ('nb_inplace_remainder', BinaryFunc_p),
        ('nb_inplace_power', TernaryFunc_p),
        ('nb_inplace_lshift', BinaryFunc_p),
        ('nb_inplace_rshift', BinaryFunc_p),
        ('nb_inplace_and', BinaryFunc_p),
        ('nb_inplace_xor', BinaryFunc_p),
        ('nb_inplace_or', BinaryFunc_p),
        ('nb_floor_divide', BinaryFunc_p),
        ('nb_true_divide', BinaryFunc_p),
        ('nb_inplace_floor_divide', BinaryFunc_p),
        ('nb_inplace_true_divide', BinaryFunc_p),
        # unaryfunc in C. A binary prototype would make a ctypes callback
        # read a second argument that CPython never passes.
        ('nb_index', UnaryFunc_p),
        ('nb_matrix_multiply', BinaryFunc_p),
        ('nb_inplace_matrix_multiply', BinaryFunc_p),
    ]
class PySequenceMethods(ctypes.Structure):
    """ctypes mirror of CPython's ``PySequenceMethods`` table (``tp_as_sequence``).

    Entries follow the C declaration order; the C typedef of each slot is
    noted on the right.
    """

    _fields_ = [
        ('sq_length', LenFunc_p),               # lenfunc
        ('sq_concat', BinaryFunc_p),            # binaryfunc
        ('sq_repeat', SSizeArgFunc_p),          # ssizeargfunc
        ('sq_item', SSizeArgFunc_p),            # ssizeargfunc
        ('was_sq_slice', ctypes.c_void_p),      # unused placeholder
        ('sq_ass_item', SSizeObjArgProc_p),     # ssizeobjargproc
        ('was_sq_ass_slice', ctypes.c_void_p),  # unused placeholder
        ('sq_contains', ObjObjProc_p),          # objobjproc
        ('sq_inplace_concat', BinaryFunc_p),    # binaryfunc
        ('sq_inplace_repeat', SSizeArgFunc_p),  # ssizeargfunc
    ]
class PyMappingMethods(ctypes.Structure):
    """ctypes mirror of CPython's ``PyMappingMethods`` table (``tp_as_mapping``).

    ``mp_subscript`` is the slot behind ``obj[key]``; ``patch_getitem`` swaps it.
    """

    _fields_ = [
        ('mp_length', LenFunc_p),                # lenfunc
        ('mp_subscript', BinaryFunc_p),          # binaryfunc
        ('mp_ass_subscript', ObjObjArgProc_p),   # objobjargproc
    ]
class PyTypeObject(ctypes.Structure):
    """ctypes view of CPython's ``PyTypeObject``; ``_fields_`` is assigned later in this module."""
class PyAsyncMethods(ctypes.Structure):
    """Placeholder for CPython's ``PyAsyncMethods`` (``tp_as_async``); no fields are declared."""
# Object header: reference count followed by the object's type pointer.
# Assigned after the class statement because PyTypeObject (which embeds a
# PyObject as its ob_base) must exist first.
# NOTE(review): this matches the default (GIL) build; free-threaded CPython
# builds use a larger object header.
PyObject._fields_ = [
    ('ob_refcnt', Py_ssize_t),                   # reference count
    ('ob_type', ctypes.POINTER(PyTypeObject)),   # type of the object
]
# Leading part of CPython's PyTypeObject, up to and including tp_new. Instances
# are only created with PyTypeObject.from_address(id(cls)), so later slots can
# be omitted. Only field order and size matter for the layout. Slots whose
# exact C type is not needed are declared as plain c_void_p. Field names are
# kept as-is for compatibility, even where they differ from CPython's.
PyTypeObject._fields_ = [
    # PyObject_VAR_HEAD
    ('ob_base', PyObject),
    ('ob_size', Py_ssize_t),
    # type declaration
    ('tp_name', ctypes.c_char_p),
    ('tp_basicsize', Py_ssize_t),
    ('tp_itemsize', Py_ssize_t),
    ('tp_dealloc', ctypes.CFUNCTYPE(None, PyObject_p)),
    # tp_print up to Python 3.7, tp_vectorcall_offset (Py_ssize_t) since 3.8;
    # pointer-sized either way.
    ('printfunc', ctypes.CFUNCTYPE(ctypes.c_int, PyObject_p, FILE_p, ctypes.c_int)),
    ('getattrfunc', ctypes.CFUNCTYPE(PyObject_p, PyObject_p, ctypes.c_char_p)),
    ('setattrfunc', ctypes.CFUNCTYPE(ctypes.c_int, PyObject_p, ctypes.c_char_p, PyObject_p)),
    # Pointer to a PyAsyncMethods table, not a function pointer.
    ('tp_as_async', ctypes.POINTER(PyAsyncMethods)),
    ('tp_repr', ctypes.CFUNCTYPE(PyObject_p, PyObject_p)),
    ('tp_as_number', ctypes.POINTER(PyNumberMethods)),
    ('tp_as_sequence', ctypes.POINTER(PySequenceMethods)),
    ('tp_as_mapping', ctypes.POINTER(PyMappingMethods)),
    # hashfunc returns Py_hash_t, which has the width of Py_ssize_t.
    ('tp_hash', ctypes.CFUNCTYPE(Py_ssize_t, PyObject_p)),
    ('tp_call', ctypes.CFUNCTYPE(PyObject_p, PyObject_p, PyObject_p, PyObject_p)),
    ('tp_str', ctypes.CFUNCTYPE(PyObject_p, PyObject_p)),
    # Untyped (c_void_p) slots: declared only to keep later offsets right.
    ('tp_getattro', ctypes.c_void_p),
    ('tp_setattro', ctypes.c_void_p),
    ('tp_as_buffer', ctypes.c_void_p),
    # unsigned long in C; the next field is pointer-aligned, so a
    # pointer-sized declaration keeps the following offsets correct.
    ('tp_flags', ctypes.c_void_p),
    ('tp_doc', ctypes.c_void_p),
    ('tp_traverse', ctypes.c_void_p),
    ('tp_clear', ctypes.c_void_p),
    ('tp_richcompare', ctypes.c_void_p),
    ('tp_weaklistoffset', ctypes.c_void_p),
    ('tp_iter', ctypes.c_void_p),
    ('iternextfunc', ctypes.c_void_p),  # tp_iternext in CPython
    ('tp_methods', ctypes.c_void_p),
    ('tp_members', ctypes.c_void_p),
    ('tp_getset', ctypes.c_void_p),
    ('tp_base', ctypes.c_void_p),
    ('tp_dict', ctypes.c_void_p),
    ('tp_descr_get', ctypes.c_void_p),
    ('tp_descr_set', ctypes.c_void_p),
    ('tp_dictoffset', ctypes.c_void_p),
    ('tp_init', ctypes.c_void_p),
    ('tp_alloc', ctypes.c_void_p),
    # newfunc(subtype, args, kwds). kwds can be NULL, which a py_object
    # argument cannot represent, hence c_void_p (int or None in callbacks).
    ('tp_new', ctypes.CFUNCTYPE(PyObject_p, PyObject_p, PyObject_p, ctypes.c_void_p)),
    # ...
]
class SlotsPointer(PyObject):
    """Assumed layout of a ``mappingproxy`` (what ``klass.__dict__`` returns).

    It is the object header inherited from ``PyObject`` followed by a pointer
    to the wrapped mapping, exposed here as ``dict``.
    """

    _fields_ = [
        ('dict', ctypes.POINTER(PyObject)),  # the proxied mapping
    ]
# Direct handles on CPython's Py_DecRef / Py_IncRef. Both take one object
# (passed as py_object) and return nothing.
_decref = ctypes.pythonapi.Py_DecRef
_incref = ctypes.pythonapi.Py_IncRef
for _refcount_fn in (_decref, _incref):
    _refcount_fn.argtypes = [ctypes.py_object]
    _refcount_fn.restype = None
del _refcount_fn
def proxy_builtin(klass):
    """Return the real, writable dict behind ``klass.__dict__``.

    For a class, ``klass.__dict__`` is a read-only ``mappingproxy``. The dict
    it wraps is the proxy's only GC referent, so it is fetched with
    ``gc.get_referents``. This avoids reading the proxy's C struct through
    ctypes, which depends on the interpreter's object-header layout. Entries
    of the returned dict can be replaced in place, including on built-in and
    extension types whose ``setattr`` is refused.

    Writes to the returned dict bypass ``type.__setattr__``, so CPython's type
    attribute cache is not invalidated automatically (see ``PyType_Modified``).

    Raises:
        TypeError: ``klass.__dict__`` is missing or is not a ``mappingproxy``
            (e.g. ``klass`` is not a class), or the proxy does not wrap a dict.
    """
    proxy = getattr(klass, '__dict__', None)
    if not isinstance(proxy, types.MappingProxyType):
        raise TypeError(
            'expected a class with a mappingproxy __dict__, got %r' % (klass,))
    referents = gc.get_referents(proxy)
    if len(referents) != 1 or not isinstance(referents[0], dict):
        raise TypeError('could not resolve the dict behind %r.__dict__' % (klass,))
    return referents[0]
# Every callback installed into a tp_new slot by patch_new. ctypes releases
# the native trampoline of a CFUNCTYPE object once that object is
# garbage-collected, so each callback must outlive its patch. Entries are
# never released: after a revert, another patch may still hold one of these
# addresses as the "original" it will restore.
_NEW_KEEPALIVE = []


def patch_new(base_cls, func):
    """Redirect construction of ``base_cls`` to ``func``.

    The C-level ``tp_new`` slot is replaced by a ctypes callback around
    ``func``. If the class ``__dict__`` defines ``__new__``, that entry is
    replaced by the same Python wrapper. When ``func`` raises
    ``NotImplementedError`` the wrapper returns ``NotImplemented`` instead.

    Args:
        base_cls: the class to patch.
        func: the replacement callable. Through the slot it receives
            ``(cls, args_tuple, kwds)``, where ``kwds`` is the raw address of
            the keyword dict as an int, or ``None`` when it is NULL. Through
            ``__dict__['__new__']`` it receives the Python-level arguments.

    Returns:
        ``(orig_mp_funcs, orig_gm_funcs)``: one-element lists holding the
        original ``tp_new`` address (``ctypes.c_void_p``) and the original
        ``__dict__['__new__']`` (or ``None``). Pass this tuple to
        ``revert_new`` to undo the patch.
    """
    assert callable(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The result goes back to C through a py_object restype, so the
        # object itself is returned (not its address).
        try:
            return func(*args, **kwargs)
        except NotImplementedError:
            return NotImplementedRet

    cfunc_t = dict(PyTypeObject._fields_)['tp_new']
    cfunc = cfunc_t(wrapper)
    # Hold the callback explicitly. Adjusting refcounts of the temporary
    # from_address() wrapper (as done previously) does not touch the type.
    _NEW_KEEPALIVE.append(cfunc)

    tyobj = PyTypeObject.from_address(id(base_cls))
    orig_mp = ctypes.cast(tyobj.tp_new, ctypes.c_void_p)
    tyobj.tp_new = cfunc

    cls_dict = proxy_builtin(base_cls)
    orig_gm = cls_dict.get('__new__', None)
    if orig_gm is not None:
        cls_dict['__new__'] = wrapper
    # The dict was written directly (bypassing type.__setattr__), so the type
    # attribute cache has to be invalidated explicitly.
    ctypes.pythonapi.PyType_Modified(ctypes.py_object(base_cls))
    return [orig_mp], [orig_gm]


def revert_new(base_cls, func):
    """Undo a ``patch_new(base_cls, ...)`` call.

    Args:
        base_cls: the class passed to ``patch_new``.
        func: the ``(orig_mp_funcs, orig_gm_funcs)`` tuple it returned.

    No reference counts are adjusted here. The installed callback stays in
    ``_NEW_KEEPALIVE``, so restoring the slot never frees code that might
    still be referenced.
    """
    orig_mp_funcs, orig_gm_funcs = func
    cls_list = [base_cls]
    cfunc_t = dict(PyTypeObject._fields_)['tp_new']
    for klass, orig_mp in zip(cls_list, orig_mp_funcs):
        tyobj = PyTypeObject.from_address(id(klass))
        tyobj.tp_new = ctypes.cast(orig_mp, cfunc_t)
    for klass, orig_gm in zip(cls_list, orig_gm_funcs):
        if orig_gm is not None:
            proxy_builtin(klass)['__new__'] = orig_gm
    for klass in cls_list:
        ctypes.pythonapi.PyType_Modified(ctypes.py_object(klass))
# Every callback installed by patch_getitem, plus any mapping table allocated
# for a type that had none. ctypes releases the native trampoline of a
# CFUNCTYPE object once that object is garbage-collected, and the allocated
# tables are referenced only from C, so both must outlive the patch. Entries
# are never released: after a revert, another patch may still hold one of
# these addresses as the "original" it will restore.
_GETITEM_KEEPALIVE = []

# Classes modified by each patch_getitem() call, keyed by id() of the
# orig_mp_funcs list it returned. The list itself is stored too, so that id
# cannot be reused before revert_getitem() pops the entry.
_GETITEM_PATCHED_CLASSES = {}


def _getitem_targets(base_cls):
    """Return ``base_cls``'s direct bases, ``base_cls`` itself and its direct subclasses."""
    return list(base_cls.__bases__) + [base_cls] + base_cls.__subclasses__()


def patch_getitem(base_cls, func):
    """Route subscription (``obj[key]``) of ``base_cls`` and related classes to ``func``.

    For every class in ``_getitem_targets(base_cls)``, the
    ``tp_as_mapping->mp_subscript`` slot is replaced by a ctypes callback. An
    existing ``__getitem__`` entry in the class ``__dict__`` is replaced by
    the same Python wrapper. When ``func`` raises ``NotImplementedError`` the
    wrapper returns ``NotImplemented`` instead.

    Args:
        base_cls: the class whose subscription should be intercepted.
        func: the replacement, called as ``func(obj, key)``.

    Returns:
        ``(orig_mp_funcs, orig_gm_funcs)``: per patched class (in
        ``_getitem_targets`` order), the original ``mp_subscript`` address
        (``ctypes.c_void_p``) and the original ``__dict__['__getitem__']``
        (``None`` if the class did not define one). Pass this tuple to
        ``revert_getitem`` to undo the patch.
    """
    assert callable(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The result goes back to C through a py_object restype, so the
        # object itself is returned (not its address).
        try:
            return func(*args, **kwargs)
        except NotImplementedError:
            return NotImplementedRet

    cls_list = _getitem_targets(base_cls)
    cfunc_t = dict(PyMappingMethods._fields_)['mp_subscript']
    cfunc = cfunc_t(wrapper)
    _GETITEM_KEEPALIVE.append(cfunc)

    orig_mp_funcs = []
    for klass in cls_list:
        tyobj = PyTypeObject.from_address(id(klass))
        tp_as_ptr = tyobj.tp_as_mapping
        if not tp_as_ptr:
            # No mapping table yet: install a zero-filled one, keep it alive
            # and patch that table rather than the old NULL pointer.
            # NOTE(review): if ``object`` ends up in cls_list, classes created
            # afterwards may inherit this slot — confirm callers avoid that.
            tp_as_obj = PyMappingMethods()
            _GETITEM_KEEPALIVE.append(tp_as_obj)
            tp_as_ptr = ctypes.pointer(tp_as_obj)
            tyobj.tp_as_mapping = tp_as_ptr
        tp_as = tp_as_ptr[0]
        orig_mp_funcs.append(ctypes.cast(tp_as.mp_subscript, ctypes.c_void_p))
        tp_as.mp_subscript = cfunc

    orig_gm_funcs = []
    for klass in cls_list:
        cls_dict = proxy_builtin(klass)
        orig_gm = cls_dict.get('__getitem__', None)
        orig_gm_funcs.append(orig_gm)
        if orig_gm is not None:
            cls_dict['__getitem__'] = wrapper
        # The dict was written directly (bypassing type.__setattr__), so the
        # type attribute cache has to be invalidated explicitly.
        ctypes.pythonapi.PyType_Modified(ctypes.py_object(klass))

    _GETITEM_PATCHED_CLASSES[id(orig_mp_funcs)] = (orig_mp_funcs, cls_list)
    return orig_mp_funcs, orig_gm_funcs


def revert_getitem(base_cls, func):
    """Undo a ``patch_getitem(base_cls, ...)`` call.

    Restores ``mp_subscript`` and ``__dict__['__getitem__']`` on every class
    that call patched, including the subclasses of ``base_cls``.

    Args:
        base_cls: the class passed to ``patch_getitem``.
        func: the ``(orig_mp_funcs, orig_gm_funcs)`` tuple it returned.
    """
    orig_mp_funcs, orig_gm_funcs = func
    record = _GETITEM_PATCHED_CLASSES.pop(id(orig_mp_funcs), None)
    if record is not None:
        cls_list = record[1]
    else:
        # Unknown (or already reverted) patch: recompute the targets.
        cls_list = _getitem_targets(base_cls)
    cfunc_t = dict(PyMappingMethods._fields_)['mp_subscript']

    # Undo last-to-first. Types may share a mapping table (a static type
    # without its own table can reuse its base's). In that case a later class
    # recorded an earlier class's patched callback as its "original", so
    # restoring in reverse leaves the true original in place.
    for klass, orig_mp in reversed(list(zip(cls_list, orig_mp_funcs))):
        tp_as_ptr = PyTypeObject.from_address(id(klass)).tp_as_mapping
        if tp_as_ptr:
            tp_as_ptr[0].mp_subscript = ctypes.cast(orig_mp, cfunc_t)

    for klass, orig_gm in zip(cls_list, orig_gm_funcs):
        if orig_gm is not None:
            proxy_builtin(klass)['__getitem__'] = orig_gm
    for klass in cls_list:
        ctypes.pythonapi.PyType_Modified(ctypes.py_object(klass))