def _upgrade_state_dict()

in fairseq/checkpoint_utils.py

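This helper rewrites legacy fairseq checkpoint dictionaries in place so that current code can load them: it introduces optimizer_history, folds loose top-level keys (epoch, batch_offset, val_loss) into extra_state, keeps only the most recent optimizer state, records the optimizer and criterion names, splits max_positions into source/target variants, defaults the task to translation, and back-fills any argparse defaults declared by the checkpoint's task, architecture, and registry choices.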

def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    import argparse

    from fairseq import models, registry, tasks

    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    # default to translation task
    if not hasattr(state['args'], 'task'):
        state['args'].task = 'translation'

    def set_defaults(cls):
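        """Back-fill missing entries in state['args'] with the argparse defaults that cls declares via add_args."""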
        if not hasattr(cls, 'add_args'):
            return
        parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
        cls.add_args(parser)
        # copied from argparse.py:
        defaults = argparse.Namespace()
        for action in parser._actions:
            if action.dest is not argparse.SUPPRESS:
                if not hasattr(defaults, action.dest):
                    if action.default is not argparse.SUPPRESS:
                        setattr(defaults, action.dest, action.default)
        for key, default_value in vars(defaults).items():
            if not hasattr(state['args'], key):
                setattr(state['args'], key, default_value)

    # set any missing default values in the task, model or other registries
    set_defaults(tasks.TASK_REGISTRY[state['args'].task])
    set_defaults(models.ARCH_MODEL_REGISTRY[state['args'].arch])
    for registry_name, REGISTRY in registry.REGISTRIES.items():
        choice = getattr(state['args'], registry_name, None)
        if choice is not None:
            cls = REGISTRY['registry'][choice]
            set_defaults(cls)

    return state
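
For reference, here is a minimal sketch of how the upgraded state might be consumed. In fairseq itself the upgrade normally happens inside load_checkpoint_to_cpu rather than by calling this helper directly; the checkpoint path and the surrounding setup below are illustrative assumptions, not part of checkpoint_utils.py.

# Illustrative usage sketch (assumed path, no error handling).
import torch

from fairseq import tasks
from fairseq.checkpoint_utils import _upgrade_state_dict

state = torch.load('checkpoint_last.pt', map_location='cpu')
state = _upgrade_state_dict(state)  # normalize legacy keys and args in place

# state['args'] now carries the task/arch/registry defaults filled in by set_defaults
task = tasks.setup_task(state['args'])
model = task.build_model(state['args'])
model.load_state_dict(state['model'], strict=True)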