chatlearn/utils/future.py (53 lines of code) (raw):
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""get remote object"""
import ray
from chatlearn.utils.logger import logging_tqdm
from chatlearn.utils.utils import flatten
def check_nested_2_level_list(refs):
"""
Checks if a list is a nested list with a nested level of 2.
e.g.
[[ref0, ref1], [ref2, ref3]] returns True, [2, 2]
[ref0, ref1] returns False, -1
[[ref0], [ref1, ref2]] returns True, [1, 2]
Returns a tuple containing two elements:
- A boolean value indicating if the list is a nested 2-level list
- A list of integers containing the length of each sublist
"""
sublist_lens = []
for sublist in refs:
if isinstance(sublist, list):
if len(sublist) == 0:
sublist_lens.append(0)
else:
if isinstance(sublist[0], ray.ObjectRef):
sublist_lens.append(len(sublist))
else:
return False, None
else:
return False, None
return True, sublist_lens
def wait(refs, desc=None, return_output=False):
"""
wait until all computation finish
"""
if isinstance(refs, ray.ObjectRef):
ray.get(refs)
return
if len(refs) == 0:
return
nested2, sublist_lens = check_nested_2_level_list(refs)
refs = flatten(refs)
if desc is not None:
total = len(refs) if not nested2 else len(sublist_lens)
pbar = logging_tqdm(total=total, desc=desc)
i = 0
wait_refs = refs.copy()
while wait_refs:
num_returns = 1 if not nested2 else sublist_lens[i]
done, wait_refs = ray.wait(wait_refs, num_returns=num_returns)
i += 1
if desc is not None:
done_size = len(done) if not nested2 else 1
pbar.update(done_size)
if return_output:
outputs = ray.get(refs)
if desc is not None:
pbar.close()
if return_output:
return outputs
def get(data):
"""get remote data"""
if isinstance(data, (list, tuple)):
dtype = type(data)
ret = dtype(get(item) for item in data)
return ret
if isinstance(data, dict):
return {key: get(value) for key, value in data.items()}
while isinstance(data, ray.ObjectRef):
data = ray.get(data)
return data