core/maxframe/learn/contrib/llm/core.py (47 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
import numpy as np
import pandas as pd
from ....core.entity.output_types import OutputType
from ....core.operator.base import Operator
from ....core.operator.core import TileableOperatorMixin
from ....dataframe.core import SERIES_TYPE
from ....dataframe.operators import DataFrameOperatorMixin
from ....dataframe.utils import parse_index
from ....serialization.serializables.core import Serializable
from ....serialization.serializables.field import AnyField, DictField, StringField
class LLM(Serializable):
    """Serializable description of a large language model.

    The only field declared here is ``name``, an optional string that
    defaults to ``None``.
    """

    name = StringField("name", default=None)

    def validate_params(self, params: Dict[str, Any]):
        """Hook for checking ``params``; this base implementation does nothing.

        NOTE(review): presumably subclasses override this to reject
        unsupported parameters -- confirm against the concrete models.
        """
        return None
class LLMTaskOperator(Operator, DataFrameOperatorMixin):
    """Base operator that turns an input tileable into a DataFrame result.

    Subclasses describe the result schema by implementing
    :meth:`get_output_dtypes`; :meth:`__call__` uses that schema to build
    the output DataFrame metadata.
    """

    task = AnyField("task", default=None)
    model = AnyField("model", default=None)
    params = DictField("params", default=None)
    running_options: Dict[str, Any] = DictField("running_options", default=None)

    def __init__(self, output_types=None, **kw):
        """Create the operator, defaulting the output type to a DataFrame.

        Parameters
        ----------
        output_types : list, optional
            Output types forwarded to the base class as ``_output_types``.
            ``[OutputType.dataframe]`` is used when omitted.
        **kw
            Extra keyword arguments forwarded to the base constructor.
        """
        if output_types is None:
            output_types = [OutputType.dataframe]
        super().__init__(_output_types=output_types, **kw)

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        """Return the mapping of output column name to dtype.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError

    def __call__(self, data, index=None):
        """Build the output DataFrame for ``data``.

        Parameters
        ----------
        data
            Input tileable. When it is a ``SERIES_TYPE`` instance a new
            range-based index value is derived for the output; otherwise
            ``data.index_value`` is reused.
        index : optional
            Explicit index value for the output. Whenever it is not
            ``None`` it takes precedence over the index derived from
            ``data``.

        Returns
        -------
        The object produced by ``self.new_dataframe``, with one column per
        entry of :meth:`get_output_dtypes` and a row count given as NaN.
        """
        outputs = self.get_output_dtypes()
        col_names = list(outputs)
        columns_value = parse_index(pd.Index(col_names), store_data=True)
        out_dtypes = pd.Series(list(outputs.values()), index=col_names)

        # Test against None explicitly instead of relying on truthiness:
        # ``index or ...`` would silently drop a valid-but-falsy index and
        # raises for objects whose truth value is ambiguous.
        if index is not None:
            index_value = index
        elif isinstance(data, SERIES_TYPE):
            index_value = parse_index(pd.RangeIndex(-1), data)
        else:
            index_value = data.index_value

        return self.new_dataframe(
            inputs=[data],
            shape=(np.nan, len(col_names)),
            index_value=index_value,
            columns_value=columns_value,
            dtypes=out_dtypes,
        )
class LLMTextGenOperator(LLMTaskOperator, TileableOperatorMixin):
    """LLM task operator whose output has ``response`` and ``success`` columns.

    Adds a ``prompt_template`` field (any value, ``None`` by default) on top
    of the fields inherited from ``LLMTaskOperator``.
    """

    prompt_template = AnyField("prompt_template", default=None)

    def get_output_dtypes(self) -> Dict[str, np.dtype]:
        """Return the output schema: object ``response``, boolean ``success``."""
        response_dtype = np.dtype("O")
        success_dtype = np.dtype("bool")
        return {"response": response_dtype, "success": success_dtype}