python/tvm/ffi/cython/object.pxi (123 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
_CLASS_OBJECT = None
_FUNC_CONVERT_TO_OBJECT = None
def _set_class_object(cls):
global _CLASS_OBJECT
_CLASS_OBJECT = cls
def _set_func_convert_to_object(func):
global _FUNC_CONVERT_TO_OBJECT
_FUNC_CONVERT_TO_OBJECT = func
def __object_repr__(obj):
"""Object repr function that can be overridden by assigning to it"""
return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"
def __object_save_json__(obj):
"""Object repr function that can be overridden by assigning to it"""
raise NotImplementedError("JSON serialization depends on downstream init")
def __object_load_json__(json_str):
"""Object repr function that can be overridden by assigning to it"""
raise NotImplementedError("JSON serialization depends on downstream init")
def __object_dir__(obj):
"""Object dir function that can be overridden by assigning to it"""
return []
def __object_getattr__(obj, name):
"""Object getattr function that can be overridden by assigning to it"""
raise AttributeError()
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class ObjectGeneric:
"""Base class for all classes that can be converted to object."""
def asobject(self):
"""Convert value to object"""
raise NotImplementedError()
class ObjectRValueRef:
"""Represent an RValue ref to an object that can be moved.
Parameters
----------
obj : tvm.runtime.Object
The object that this value refers to
"""
__slots__ = ["obj"]
def __init__(self, obj):
self.obj = obj
cdef class Object:
"""Base class of all TVM FFI objects.
"""
cdef void* chandle
def __dealloc__(self):
if self.chandle != NULL:
CHECK_CALL(TVMFFIObjectFree(self.chandle))
def __ctypes_handle__(self):
return ctypes_handle(self.chandle)
def __chandle__(self):
cdef uint64_t chandle = <uint64_t>self.chandle
return chandle
def __reduce__(self):
cls = type(self)
return (_new_object, (cls,), self.__getstate__())
def __getstate__(self):
if not self.__chandle__() == 0:
# need to explicit convert to str in case String
# returned and triggered another infinite recursion in get state
return {"handle": str(__object_save_json__(self))}
return {"handle": None}
def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state["handle"]
if handle is not None:
self.__init_handle_by_constructor__(__object_load_json__, handle)
else:
self.chandle = NULL
def __getattr__(self, name):
try:
return __object_getattr__(self, name)
except AttributeError:
raise AttributeError(f"{type(self)} has no attribute {name}")
def __dir__(self):
return __object_dir__(self)
def __repr__(self):
# make sure repr is a raw string
return str(__object_repr__(self))
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def __init_handle_by_load_json__(self, json_str):
raise NotImplementedError("JSON serialization depends on downstream init")
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# avoid error raised during construction.
self.chandle = NULL
cdef void* chandle
ConstructorCall(
(<Object>fconstructor).chandle, args, &chandle)
self.chandle = chandle
def same_as(self, other):
"""Check object identity.
Parameters
----------
other : object
The other object to compare against.
Returns
-------
result : bool
The comparison result.
"""
if not isinstance(other, Object):
return False
return self.chandle == (<Object>other).chandle
def __hash__(self):
cdef uint64_t hash_value = <uint64_t>self.chandle
return hash_value
def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
This is a advanced developer API that can be useful when passing an
unique reference to an Object that you no longer needed to a function.
A unique reference can trigger copy on write optimization that avoids
copy when we transform an object.
Note
----
All the reference of the object becomes invalid after it is moved.
Be very careful when using this feature.
Returns
-------
rvalue : The rvalue reference.
"""
return ObjectRValueRef(self)
def __move_handle_from__(self, other):
"""Move the handle from other to self"""
self.chandle = (<Object>other).chandle
(<Object>other).chandle = NULL
class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []
def __init_tvm_ffi_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_ffi_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_ffi_object__ = obj
"""Maps object type index to its constructor"""
cdef list OBJECT_TYPE = []
"""Maps object type to its type index"""
cdef dict OBJECT_INDEX = {}
def _register_object_by_index(int index, object cls):
"""register object class"""
global OBJECT_TYPE
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls
OBJECT_INDEX[cls] = index
def _object_type_key_to_index(str type_key):
"""get the type index of object class"""
cdef int32_t tidx
type_key_arg = ByteArrayArg(c_str(type_key))
if TVMFFITypeKeyToIndex(type_key_arg.cptr(), &tidx) == 0:
return tidx
return None
cdef inline object make_ret_object(TVMFFIAny result):
global OBJECT_TYPE
cdef int32_t tindex
cdef object cls
tindex = result.type_index
if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = Object.__new__(Object)
(<Object>obj).chandle = result.v_obj
return cls.__from_tvm_ffi_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<Object>obj).chandle = result.v_obj
return obj
_set_class_object(Object)