chatlearn/runtime/utils.py (54 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""runtime utils"""

import ast
import textwrap
import inspect
from collections import defaultdict


def encode_data(mb, data):
    """Pack an iteration/micro-batch id and its payload into one dict.

    Args:
        mb: identifier stored under the ``"iter"`` key.
        data: payload stored under the ``"data"`` key.

    Returns:
        ``{"iter": mb, "data": data}`` -- the inverse of :func:`decode_data`.
    """
    return {"iter": mb, "data": data}


def decode_data(data):
    """Unpack a dict produced by :func:`encode_data`.

    Args:
        data: mapping with ``"iter"`` and ``"data"`` keys.

    Returns:
        ``(mb, data)`` tuple.

    Raises:
        KeyError: if either key is missing.
    """
    mb = data["iter"]
    data = data["data"]
    return mb, data


def parse_assign_target(line):
    """Return the names bound by an ``ast.Assign`` node, in source order.

    Chained assignments (``a = b = f()``) contribute one entry per target, and
    tuple/list unpacking (``a, (b, *c) = f()``) is flattened into its element
    names.

    Args:
        line: an ``ast.Assign`` node.

    Returns:
        list of target names (``str``).

    Raises:
        AttributeError: if a target is not a plain name, e.g. ``self.x = ...``
            or ``d[k] = ...``.
    """
    targets = []

    def _collect(target):
        # Unpacking targets nest arbitrarily; walk them depth-first so the
        # resulting order matches the left-to-right order in the source.
        if isinstance(target, (ast.Tuple, ast.List)):
            for elt in target.elts:
                _collect(elt)
        elif isinstance(target, ast.Starred):
            _collect(target.value)
        else:
            targets.append(target.id)

    for target in line.targets:
        _collect(target)
    return targets


def parse_expr(line):
    """Decompose a ``[target =] receiver.method(arg, ...)`` statement.

    Args:
        line: an ``ast.Assign`` or ``ast.Expr`` whose value is an ``ast.Call``
            on an attribute.

    Returns:
        ``(func_name, model_name, func_args)``:

        * ``func_name``: the called method name (``method``);
        * ``model_name``: the receiver's name -- the bare name for
          ``model.method()``, or the last attribute for ``obj.model.method()``;
        * ``func_args``: names of the positional arguments.

    Raises:
        AttributeError: if the value is not an attribute call, a positional
            argument is not a plain name, or the receiver is neither a name
            nor an attribute access.
    """
    func = line.value.func
    func_name = func.attr
    func_args = [arg.id for arg in line.value.args]
    if isinstance(func.value, ast.Name):
        model_name = func.value.id
    else:
        model_name = func.value.attr
    return func_name, model_name, func_args


class FlowParser:
    """Statically discover which model methods a flow function calls.

    The flow function's source is parsed with :mod:`ast`. Each statement of
    the form ``out = model.method(args)`` or ``model.method(args)`` in its
    body, including statements inside (possibly nested) ``with`` blocks, is
    recorded as a call of ``method`` on the object that ``model`` resolves to
    in the function's globals/nonlocals. Other statements (docstrings, plain
    assignments, non-attribute calls, ``return``, loops, ...) are ignored.
    """

    def __init__(self):
        # model object -> method names called on it, in source order.
        # Not reset by parse(), so it accumulates across successive calls.
        self.model_to_call_funcs = defaultdict(list)
        # name -> object used to resolve receivers; rebuilt by parse().
        self.global_models = {}

    @staticmethod
    def _is_model_call(line):
        """Return True if *line*'s value is a ``<receiver>.<method>(...)`` call."""
        value = line.value
        return isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute)

    def parse_assign(self, line):
        """Record the model method called by an ``Assign``/``Expr`` statement.

        Statements whose value is not an attribute call (e.g. a docstring or
        ``x = y``) are skipped instead of crashing the parse.

        Raises:
            KeyError: if the receiver name is not in ``self.global_models``.
        """
        if not self._is_model_call(line):
            return
        func_name, model_name, _ = parse_expr(line)
        model = self.global_models[model_name]
        self.model_to_call_funcs[model].append(func_name)

    def _iter_statements(self, body):
        """Yield statements of *body*, descending into ``with`` blocks in order."""
        for stmt in body:
            if isinstance(stmt, ast.With):
                yield from self._iter_statements(stmt.body)
            else:
                yield stmt

    def visit_func(self, node):
        """Collect model calls from the body of a ``FunctionDef`` node."""
        for line in self._iter_statements(node.body):
            if isinstance(line, (ast.Assign, ast.Expr)):
                self.parse_assign(line)

    def parse(self, func, global_models=None):
        """Parse a flow function and return its model -> called-methods map.

        Args:
            func: the flow function, or its source code as a string. A string
                has no closure, so its model names must come from
                *global_models*.
            global_models: optional extra ``name -> model`` mapping, applied
                after the function's globals/nonlocals (so it wins on a clash).

        Returns:
            ``self.model_to_call_funcs`` (a ``defaultdict(list)``).

        Raises:
            OSError: if the source of *func* cannot be retrieved.
            KeyError: if a called receiver name cannot be resolved.
        """
        self.global_models = {}
        if isinstance(func, str):
            code = textwrap.dedent(func)
        else:
            closure_vars = inspect.getclosurevars(func)
            self.global_models.update(closure_vars.globals)
            # Nonlocals shadow globals of the same name, as at runtime.
            self.global_models.update(closure_vars.nonlocals)
            code = textwrap.dedent(inspect.getsource(func))
        if global_models:
            self.global_models.update(global_models)
        node_iter = ast.NodeVisitor()
        # NodeVisitor dispatches via getattr(self, "visit_<Class>"), so an
        # instance attribute is enough; nested defs are not descended into.
        node_iter.visit_FunctionDef = self.visit_func
        node_iter.visit(ast.parse(code))
        return self.model_to_call_funcs