in chatlearn/utils/vllm_utils.py [0:0]
def _load_base_checkpoint(load_dir, rank0=False):
    """Load the base state_dict from the given directory.

    If rank0 is True, load the rank-0 checkpoint regardless of this
    process's parallel rank.
    """
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)
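    # In Megatron-style layouts this is typically
    # <load_dir>/latest_checkpointed_iteration.txt, containing either an
    # iteration number or the string 'release'.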
    # If no tracker file, return nothing.
    if not os.path.isfile(tracker_filename):
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
        return None, "", False

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    if rank0:
        checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
    else:
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
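
    # Some versions of the checkpoint-name helpers return a tuple;
    # keep only the path itself.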
    if isinstance(checkpoint_name, tuple):
        checkpoint_name = checkpoint_name[0]

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
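        # Old checkpoints pickled references to module paths that have since
        # moved; alias the legacy names in sys.modules so unpickling can
        # resolve them, then drop the aliases afterwards.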
        try:
            from megatron.fp16_deprecated import loss_scaler  # pylint: disable=import-outside-toplevel,unused-import
            # For backward compatibility.
            if not rank0:
                print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except ImportError:
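            # Newer Megatron-LM trees ship the deprecated modules under
            # megatron.legacy; retry the same aliasing with those paths.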
            from megatron.legacy.fp16_deprecated import loss_scaler  # pylint: disable=import-outside-toplevel,unused-import
            # For backward compatibility.
            if not rank0:
                print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.legacy.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.legacy.fp16_deprecated.loss_scaler']
            sys.modules['megatron.model'] = sys.modules['megatron.legacy.model']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
            sys.modules.pop('megatron.model', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    return state_dict, checkpoint_name, release
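
A minimal usage sketch, assuming a Megatron-style checkpoint directory; the path and the 'model' key below are illustrative assumptions, not taken from this file:

state_dict, checkpoint_name, release = _load_base_checkpoint(
    '/checkpoints/my_model', rank0=True)  # hypothetical load_dir
if state_dict is None:
    pass  # no tracker file was found; start from randomly initialized weights
else:
    # Assumption: Megatron-style checkpoints keep model weights under 'model'.
    model_weights = state_dict.get('model')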