chatlearn/utils/utils.py

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""utils"""

import ast
import inspect
import socket
import subprocess
import textwrap
import time
import math
import copy
import concurrent.futures
from contextlib import closing
from types import SimpleNamespace

import pynvml
import numpy as np
import torch
import torch.nn.functional as F

from chatlearn.utils.logger import logger


def get_attributes(cls):
    """Get attributes from class."""
    return [(name, attr) for name, attr in inspect.getmembers(cls)
            if not name.startswith('_') and not callable(attr)]


def parse_function_args(func):
    """Collect the argument names of a function by parsing its source."""
    args = []

    def parse_func_args(node):
        for argument in node.args.args:
            args.append(argument.arg)

    node_iter = ast.NodeVisitor()
    node_iter.visit_FunctionDef = parse_func_args
    code = textwrap.dedent(inspect.getsource(func))
    node_iter.visit(ast.parse(code))
    return args


def get_return_lines(node):
    """Find the first return statement in a function body, descending into if-branches."""
    for line in node.body:
        if isinstance(line, ast.Return):
            return line
    for line in node.body:
        if isinstance(line, ast.If):
            ret = get_return_lines(line)
            if ret is not None:
                return ret


def get_return_value_num(ret):
    if isinstance(ret.value, ast.Name):
        return 1
    if isinstance(ret.value, ast.Tuple):
        return len(ret.value.elts)
    if isinstance(ret.value, ast.Call):
        raise RuntimeError("nested call in return is currently not supported")


def parse_function_return_num(func):
    """Count how many values a function returns."""
    results = []

    def parse_func_return(node):
        ret = get_return_lines(node)
        return_num = 0
        if ret is not None:
            return_num = get_return_value_num(ret)
        results.append(return_num)

    node_iter = ast.NodeVisitor()
    node_iter.visit_FunctionDef = parse_func_return
    code = textwrap.dedent(inspect.getsource(func))
    node_iter.visit(ast.parse(code))
    return results[0]


def get_host_addr():
    """Get the IP address of the current node."""
    hostname = socket.gethostname()
    ip_addr = socket.gethostbyname(hostname)
    return ip_addr


def get_free_port():
    """Find a free port."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as handle:
        # SO_REUSEADDR must be set before bind to take effect.
        handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        handle.bind(('', 0))
        return handle.getsockname()[1]


def split_index(length, num_splits):
    """Split `length` items into `num_splits` contiguous (start, end) index ranges,
    distributing the remainder one item at a time across the leading splits."""
    size = length // num_splits
    remainder = length % num_splits
    indices = []
    start = 0
    end = 0
    for _ in range(num_splits):
        end += size
        if remainder > 0:
            end += 1
            remainder -= 1
        indices.append((start, end))
        start = end
    return indices


def to_device(device, args):
    """
    Convert args to device recursively

    Args:
        device: gpu/cpu
        args: args to be converted
    """
    if isinstance(args, (list, tuple)):
        args = type(args)(to_device(device, arg) for arg in args)
    elif isinstance(args, dict):
        for key, value in args.items():
            args[key] = to_device(device, value)
    elif isinstance(args, torch.Tensor):
        args = args.to(device)
    return args


def get_or_cache(cache, key, func, *args, **kwargs):
    """
    Return the cached result if present, otherwise call func to compute the
    result and cache it.
    """
    if key in cache:
        return cache[key]
    res = func(*args, **kwargs)
    cache[key] = res
    return res

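# Illustrative notes (not part of the original module):
# split_index distributes the remainder across the leading splits, e.g.
#   split_index(10, 3) -> [(0, 4), (4, 7), (7, 10)]
# get_or_cache memoizes by key, e.g.
#   cache = {}
#   port = get_or_cache(cache, "port", get_free_port)  # computed and cached
#   port = get_or_cache(cache, "port", get_free_port)  # served from the cache
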
def flatten(nested_list):
    """Flatten an arbitrarily nested list into a flat list."""
    flat = []
    for elem in nested_list:
        if isinstance(elem, list):
            flat.extend(flatten(elem))
        else:
            flat.append(elem)
    return flat


def get_indent_count(string):
    """Count the number of leading spaces in a line."""
    count = 0
    for s in string:
        if s == ' ':
            count += 1
        else:
            return count
    return count


def detect_and_insert_code(lines, pattern, new_code, additional_indent=0, line_offset=0, replace=False):
    """
    Insert new_code above the first line that matches pattern.
    """
    detected_lines = [(line_number, line) for line_number, line in enumerate(lines) if pattern in line]
    if not detected_lines:
        return
    type_line_number, type_line = detected_lines[0]
    indent = get_indent_count(type_line) + additional_indent
    new_lines = [line for line in new_code.split('\n') if line.strip()]
    added_lines = []
    for line in new_lines:
        added_lines.append(" " * indent + line)
    # When replace is True, the line immediately before the insertion point is dropped.
    lines = lines[:type_line_number + line_offset - int(replace)] + added_lines + lines[type_line_number + line_offset:]
    return lines


def detect_and_insert_code_to_func(source_code, pattern, new_code, additional_indent=0, line_offset=0, replace=False):
    lines = source_code.split('\n')
    lines = detect_and_insert_code(lines, pattern, new_code, additional_indent, line_offset, replace)
    if lines is None:
        return
    indent = get_indent_count(lines[0])
    lines = [line[indent:] for line in lines]
    return '\n'.join(lines)


def execute(cmd, check=False, retry=1):
    """
    Execute cmd in shell.

    Args:
        check: if returncode is non-zero, raise error
    """
    ret = subprocess.run(cmd, shell=True, capture_output=True, text=True, check=check)
    state = ret.returncode == 0
    msg = ret.stdout if state else ret.stderr
    if not state:
        logger.warning(f"execute {cmd} got error {msg}")
        if retry > 1:
            logger.warning(f"retry {cmd} ...")
            time.sleep(1)
            return execute(cmd, check, retry - 1)
    return state, msg


def is_connection_refused(msg):
    keywords = ["StatusCode.UNAVAILABLE", "Connection refused", "failed to connect to all addresses"]
    return any(keyword in msg for keyword in keywords)


def get_ray_status():
    cluster_state, msg = execute("ray status", retry=3)
    if cluster_state:
        return True, None
    if is_connection_refused(msg):
        return False, msg
    # unknown msg
    return True, msg


def get_full_proc_memory_info(prefix):
    torch.cuda.synchronize()
    s = prefix + ': '
    s += f'memory allocated: {torch.cuda.memory_allocated() / (1 << 30):.2f} GiB, ' \
         f'memory reserved: {torch.cuda.memory_reserved() / (1 << 30):.2f} GiB, ' \
         f'proc memory usage: {nvml_proc_memory_info()}'
    return s


def nvml_proc_memory_info():
    """Report per-process GPU memory usage of the first device via NVML."""
    pynvml.nvmlInit()
    s = ''
    for dev_id in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id)
        mem_str = ' | '.join([f'(pid {proc.pid}: {proc.usedGpuMemory / (1 << 30):.2f} GiB)'
                              for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle)])
        s += mem_str
        break
    return s


def dict_to_simplenamespace(d):
    """Recursively convert a nested dict into a SimpleNamespace."""
    for key, value in d.items():
        if isinstance(value, dict):
            d[key] = dict_to_simplenamespace(value)
    return SimpleNamespace(**d)


def get_use_legacy_models(model_args):
    if isinstance(model_args, dict):
        use_legacy_models = model_args.get("use_legacy_models", None)
    else:
        use_legacy_models = getattr(model_args, "use_legacy_models", None)
    if use_legacy_models is None:
        raise RuntimeError("Please specify use_legacy_models (True or False), not None.")
    return use_legacy_models


def execute_in_parallel(function, arguments):
    """Run function over a list of argument tuples in a thread pool;
    results are gathered in completion order."""
    if len(arguments) == 1:
        return function(*arguments[0])
    results = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(function, *args) for args in arguments]
        for _future in concurrent.futures.as_completed(futures):
            results.append(_future.result())
    return results

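# Usage sketch (illustrative, not from the original source):
#   state, msg = execute("ray status", retry=3)          # shell out with retries
#   results = execute_in_parallel(execute, [("echo a",), ("echo b",)])
# Note: execute_in_parallel gathers results in completion order, not in the
# order the argument tuples were submitted.
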
def multi_thread_data_processing(num_threads: int, all_data: list, process_one_data, fn_args: list):
    """Process all_data with process_one_data(data, *fn_args) across num_threads threads,
    preserving the order of all_data in the returned list."""
    num_data = len(all_data)
    if num_data == 0:
        return []
    # reduce num_threads if there is little data
    if num_data < num_threads:
        num_threads = num_data
    assert num_threads > 0, "Got num_threads <= 0. Expected a positive number."
    result = list(range(num_data))
    data_size_per_thread = math.ceil(num_data / num_threads)
    thread_args = [(i, all_data[i * data_size_per_thread : min((i + 1) * data_size_per_thread, num_data)])
                   for i in range(num_threads)]

    def thread_fn(thread_id: int, pending_data):
        offset = thread_id * data_size_per_thread
        for i, data in enumerate(pending_data):
            result[offset + i] = process_one_data(data, *fn_args)

    execute_in_parallel(thread_fn, thread_args)
    return result


def regroup_by_concat_along_batch(tensors):
    """Merge a list of per-batch dicts into one dict, concatenating values along
    the batch dimension. 2-D tensors are right-padded to a common width first."""
    batched = {}
    if tensors[0] is None:
        return batched
    for key in tensors[0].keys():
        to_batch = [results[key] for results in tensors]
        if isinstance(to_batch[0], torch.Tensor):
            if len(to_batch[0].shape) == 2:
                max_dim_1 = max(ele.shape[1] for ele in to_batch)
                pad_value = 0.0 if to_batch[0].dtype in [torch.float32, torch.float16, torch.bfloat16] else 0
                value = [
                    F.pad(
                        ele,
                        (0, max_dim_1 - ele.shape[1]),
                        value=pad_value,
                    )
                    for ele in to_batch
                ]
                batched[key] = torch.vstack(value)
            elif len(to_batch[0].shape) == 1:
                batched[key] = torch.concat(to_batch)
            else:
                raise RuntimeError(
                    f"unsupported shape for in_queue rebatching: expect 1 or 2 dims, got {to_batch[0].shape}")
        elif isinstance(to_batch[0], list):
            batched[key] = []
            for seq in to_batch:
                batched[key].extend(seq)
        else:
            raise Exception(f"unknown type {type(to_batch[0])} for key {key} to concat: {to_batch[0]}")
    return batched


def slice_by_index_along_batch(batched_input, index):
    """Take every `offset`-th element along the batch dimension starting at `start`,
    where index is a (start, offset) pair."""
    start = index[0]
    offset = index[1]
    batched = {}
    for key in batched_input.keys():
        if isinstance(batched_input[key], torch.Tensor):
            batched[key] = batched_input[key][start::offset, ...]
        elif isinstance(batched_input[key], list):
            batched[key] = batched_input[key][start::offset]
    return batched

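# Illustrative round trip (assumed shapes, not from the original source):
# regroup_by_concat_along_batch right-pads 2-D tensors to a common width before
# stacking, e.g. values of shape (2, 3) and (2, 5) become one (4, 5) tensor;
# slice_by_index_along_batch(batched, (start, offset)) then takes
# batched[key][start::offset] from each value.
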
def listdict_to_dictlist(ld, list_extend=True):
    '''
    [{k1: v11, k2: v21}, {k1: v12, k2: v22}, ...]
    => {k1: [v11, v12, ...], k2: [v21, v22, ...]}
    If v11 is a list and list_extend is True, then k1: v11 + v12.
    :param ld:
    :return:
    '''
    res = copy.deepcopy(ld[0])
    for res_key, v in res.items():
        if list_extend and isinstance(res[res_key], list):
            continue
        res[res_key] = [v]
    for d in ld[1:]:
        for key, v in d.items():
            if list_extend and isinstance(d[key], list):
                res[key].extend(v)
            else:
                res[key].append(v)
    return res


def map_metrics(metric_list):
    """Merge a list of metric dicts into one dict mapping each key to the list of its values."""
    mapped_metrics = {}
    for metrics in metric_list:
        for key, value in metrics.items():
            if key in mapped_metrics:
                mapped_metrics[key].append(value)
            else:
                mapped_metrics[key] = [value]
    return mapped_metrics


def reduce_metrics(merged_metrics):
    # [TODO:baodong.lh] support custom ops like min, max to reduce metrics
    reduced_metrics = {}
    for key, value_list in merged_metrics.items():
        if isinstance(value_list[0], torch.Tensor):
            value = torch.mean(torch.Tensor(value_list))
        else:
            value = np.mean(value_list)
        reduced_metrics[key] = value
    return reduced_metrics


def map_reduce_metrics(metric_list):
    # [TODO:baodong.lh] improve performance by distributing the task to per-replica
    # sanity check
    assert isinstance(metric_list, list)
    if len(metric_list) == 0:
        return {}
    first_metric_len = len(metric_list[0])
    for i, metric in enumerate(metric_list):
        if len(metric) != first_metric_len:
            logger.info(
                f"WARNING! Length of metrics differs between the {i}-th metric ({len(metric)}) "
                f"and the first one ({first_metric_len})! Please check!"
            )
    mapped_metrics = map_metrics(metric_list)
    reduced_metrics = reduce_metrics(mapped_metrics)
    return reduced_metrics
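

# Example sketch (illustrative, not from the original source): metrics from two
# replicas are merged by key and then mean-reduced:
#   map_reduce_metrics([{"loss": 0.5}, {"loss": 0.7}]) -> {"loss": 0.6}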