core/maxframe/learn/contrib/models.py (50 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ... import opcodes
from ...core import ENTITY_TYPE, OutputType
from ...core.operator import ObjectOperator, ObjectOperatorMixin
from ...serialization.serializables import (
AnyField,
DictField,
FunctionField,
TupleField,
)
from ...utils import find_objects, replace_objects
class ModelDataSource(ObjectOperator, ObjectOperatorMixin):
    """
    Source operator whose ``data`` field holds a model object.

    Calling an instance with a model class creates a single object-typed
    tileable with no inputs (inputs are passed as ``None``).
    """

    _op_type_ = opcodes.MODEL_DATA_SOURCE

    # The wrapped model object (``to_remote_model`` passes its ``model`` here).
    data = AnyField("data")

    def __call__(self, model_cls):
        """
        Create the output tileable for this source.

        Parameters
        ----------
        model_cls
            Forwarded as ``object_class`` to ``new_tileable``.

        Returns
        -------
        The tileable returned by ``new_tileable``.
        """
        # This operator always yields exactly one object-typed output.
        self._output_types = [OutputType.object]
        model_tileable = self.new_tileable(None, object_class=model_cls)
        return model_tileable
class ModelApplyChunk(ObjectOperator, ObjectOperatorMixin):
    """
    Operator bundling a callable ``func`` with positional (``args``) and
    keyword (``kwargs``) arguments.

    The first input is the ``t`` passed to ``__call__``. Every entity
    (``ENTITY_TYPE`` instance) found inside ``args`` / ``kwargs`` becomes an
    additional input after it. When inputs are (re)set, those embedded
    entities are replaced positionally by the operator's actual inputs.

    NOTE(review): nothing in this class invokes ``func``; presumably an
    executor registered for ``_op_module_`` / ``_op_type_`` does — confirm.
    """

    _op_module_ = "maxframe.learn.contrib.models"
    _op_type_ = opcodes.APPLY_CHUNK

    func = FunctionField("func")
    args = TupleField("args")
    kwargs = DictField("kwargs")

    def __init__(self, output_types=None, **kwargs):
        # A tuple/list of output types is copied into a fresh list; any other
        # value (including the default ``None``) becomes a one-element list.
        if isinstance(output_types, (tuple, list)):
            normalized_types = list(output_types)
        else:
            normalized_types = [output_types]
        self._output_types = normalized_types
        super().__init__(**kwargs)

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        # Collect embedded entities in the same order ``__call__`` used, so
        # they line up with ``self._inputs[1:]`` (index 0 is ``t``).
        embedded = [
            *find_objects(self.args, ENTITY_TYPE),
            *find_objects(self.kwargs, ENTITY_TYPE),
        ]
        replacements = dict(zip(embedded, self._inputs[1:]))
        self.args = replace_objects(self.args, replacements)
        self.kwargs = replace_objects(self.kwargs, replacements)

    @property
    def output_limit(self) -> int:
        # One output per entry in the normalized output-type list.
        return len(self._output_types)

    def __call__(self, t, output_kws, args=None, **kwargs):
        """
        Store the call arguments and create this operator's output tileables.

        Parameters
        ----------
        t
            Primary input; becomes the operator's first input.
        output_kws
            Forwarded as ``kws`` to ``new_tileables``; presumably one dict of
            tileable properties per output — TODO confirm.
        args : tuple, optional
            Positional arguments stored on ``self.args``; a falsy value is
            stored as ``()``.
        **kwargs
            Keyword arguments stored on ``self.kwargs``.

        Returns
        -------
        Whatever ``new_tileables`` returns for the collected inputs.
        """
        self.args = args if args else ()
        self.kwargs = kwargs
        # Entities hidden in args/kwargs must be tracked as real inputs.
        arg_entities = find_objects(self.args, ENTITY_TYPE)
        kwarg_entities = find_objects(self.kwargs, ENTITY_TYPE)
        all_inputs = [t] + arg_entities + kwarg_entities
        return self.new_tileables(all_inputs, kws=output_kws)
def to_remote_model(model, model_cls):
    """
    Wrap ``model`` in a ``ModelDataSource`` operator and return its tileable.

    Parameters
    ----------
    model
        Object stored in the operator's ``data`` field.
    model_cls
        Class handed to ``ModelDataSource.__call__``, which passes it on as
        ``object_class`` for the output tileable.

    Returns
    -------
    The tileable produced by calling the ``ModelDataSource`` operator.
    """
    return ModelDataSource(data=model)(model_cls)