sockeye/train.py [19:128]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# NOTE(review): this runs before the remaining imports below (argparse, mxnet, ...), presumably so that
# environment-related command-line arguments take effect before the framework is imported -- confirm in initial_setup.
initial_setup.handle_env_cli_arg()

import argparse
import logging
import os
import shutil
import sys
import tempfile
from contextlib import ExitStack
from typing import cast, Callable, Optional, Dict, List, Tuple

import mxnet as mx
from mxnet import gluon
from mxnet import amp

from . import arguments
from . import checkpoint_decoder
from . import constants as C
from . import data_io
from . import decoder
from . import encoder
from . import horovod_mpi
from . import layers
from . import loss
from . import lr_scheduler
from . import model
from . import training
from . import transformer
from . import utils
from . import vocab
from .config import Config
from .log import setup_main_logger
from .optimizers import OptimizerConfig
from .utils import check_condition

# Temporary module-level logger; the real one (probably logging to a file) will be created in the main function.
logger = logging.getLogger(__name__)


def none_if_negative(val):
    """
    Map negative values to None.

    :param val: A value that can be compared with 0.
    :return: None if val is below zero, otherwise val unchanged.
    """
    if val < 0:
        return None
    return val


def _list_to_tuple(v):
    """Convert v to a tuple if it is a list."""
    if isinstance(v, list):
        return tuple(v)
    return v


def _dict_difference(dict1: Dict, dict2: Dict):
    """
    Collect the keys of dict1 whose values are absent from, or different in, dict2.

    Keys present only in dict2 are not reported. A list and a tuple holding the same values are
    treated as equal, because JSON deserialization turns former tuples into lists.

    :param dict1: Reference dictionary.
    :param dict2: Dictionary compared against dict1.
    :return: Set of differing keys.
    """
    return {key for key, value in dict1.items()
            if key not in dict2 or _list_to_tuple(dict2[key]) != _list_to_tuple(value)}


def _broadcast_factor_option(values: List, num_factors: int, description: str, side: str) -> List:
    """
    Validate a per-factor option list, or broadcast a single value to every factor.

    A list with more than one entry must contain exactly one entry per factor and is returned as-is.
    Any shorter list is repeated num_factors times.

    :param values: Option values as parsed from the command line.
    :param num_factors: Number of factors the option applies to.
    :param description: Plural name of the option, used in the error message.
    :param side: 'source' or 'target', used in the error message.
    :return: The option values with one entry per factor.
    """
    if len(values) > 1:
        check_condition(num_factors == len(values),
                        'The number of %s for %s factors does not match the number of %s factors.'
                        % (description, side, side))
        return values
    # Length 1: expand the list to the appropriate length
    return values * num_factors


def check_arg_compatibility(args: argparse.Namespace):
    """
    Check if some arguments are incompatible with each other.

    Requires at least one stopping criterion and validates the per-factor options for source and target
    factors. Note that args is modified in place: single-valued ``*_factors_combine`` and
    ``*_factors_share_embedding`` lists are expanded to one entry per validation factor input.

    :param args: Arguments as returned by argparse.
    """

    # Require at least one stopping criteria. Every flag checked here is also listed in the message.
    check_condition(any((args.max_samples,
                         args.max_updates,
                         args.max_seconds,
                         args.max_checkpoints,
                         args.max_num_epochs,
                         args.max_num_checkpoint_not_improved)),
                    'Please specify at least one stopping criteria: --max-samples --max-updates --max-seconds '
                    '--max-checkpoints --max-num-epochs --max-num-checkpoint-not-improved')

    # Check and possibly adapt the parameters for source factors.
    # The number of factors is the number of entries in args.validation_source_factors.
    n_source_factors = len(args.validation_source_factors)
    args.source_factors_combine = _broadcast_factor_option(
        args.source_factors_combine, n_source_factors, 'combination strategies', 'source')
    args.source_factors_share_embedding = _broadcast_factor_option(
        args.source_factors_share_embedding, n_source_factors, 'vocabulary sharing flags', 'source')

    # Check and possibly adapt the parameters for target factors
    n_target_factors = len(args.validation_target_factors)
    args.target_factors_combine = _broadcast_factor_option(
        args.target_factors_combine, n_target_factors, 'combination strategies', 'target')
    args.target_factors_share_embedding = _broadcast_factor_option(
        args.target_factors_share_embedding, n_target_factors, 'vocabulary sharing flags', 'target')
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



sockeye/train_pt.py [19:125]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# NOTE(review): this runs before the remaining imports below (argparse, torch, ...), presumably so that
# environment-related command-line arguments take effect before the framework is imported -- confirm in initial_setup.
initial_setup.handle_env_cli_arg()

import argparse
import logging
import os
import shutil
import sys
import tempfile
from typing import cast, Callable, Optional, Dict, List, Tuple

import torch
import torch.distributed
import torch.distributed.elastic.multiprocessing.errors

from . import arguments
from . import checkpoint_decoder_pt
from . import constants as C
from . import data_io_pt
from . import encoder_pt
from . import layers_pt
from . import loss_pt
from . import lr_scheduler
from . import model_pt
from . import optimizers
from . import training_pt
from . import transformer_pt
from . import utils
from . import vocab
from .config import Config
from .log import setup_main_logger
from .utils import check_condition

# Temporary module-level logger; the real one (probably logging to a file) will be created in the main function.
logger = logging.getLogger(__name__)


def none_if_negative(val):
    """
    Map negative values to None.

    :param val: A value that can be compared with 0.
    :return: None if val is below zero, otherwise val unchanged.
    """
    if val < 0:
        return None
    return val


def _list_to_tuple(v):
    """Convert v to a tuple if it is a list."""
    if isinstance(v, list):
        return tuple(v)
    return v


def _dict_difference(dict1: Dict, dict2: Dict):
    """
    Collect the keys of dict1 whose values are absent from, or different in, dict2.

    Keys present only in dict2 are not reported. A list and a tuple holding the same values are
    treated as equal, because JSON deserialization turns former tuples into lists.

    :param dict1: Reference dictionary.
    :param dict2: Dictionary compared against dict1.
    :return: Set of differing keys.
    """
    return {key for key, value in dict1.items()
            if key not in dict2 or _list_to_tuple(dict2[key]) != _list_to_tuple(value)}


def _broadcast_factor_option(values: List, num_factors: int, description: str, side: str) -> List:
    """
    Validate a per-factor option list, or broadcast a single value to every factor.

    A list with more than one entry must contain exactly one entry per factor and is returned as-is.
    Any shorter list is repeated num_factors times.

    :param values: Option values as parsed from the command line.
    :param num_factors: Number of factors the option applies to.
    :param description: Plural name of the option, used in the error message.
    :param side: 'source' or 'target', used in the error message.
    :return: The option values with one entry per factor.
    """
    if len(values) > 1:
        check_condition(num_factors == len(values),
                        'The number of %s for %s factors does not match the number of %s factors.'
                        % (description, side, side))
        return values
    # Length 1: expand the list to the appropriate length
    return values * num_factors


def check_arg_compatibility(args: argparse.Namespace):
    """
    Check if some arguments are incompatible with each other.

    Requires at least one stopping criterion and validates the per-factor options for source and target
    factors. Note that args is modified in place: single-valued ``*_factors_combine`` and
    ``*_factors_share_embedding`` lists are expanded to one entry per validation factor input.

    :param args: Arguments as returned by argparse.
    """

    # Require at least one stopping criteria. Every flag checked here is also listed in the message.
    check_condition(any((args.max_samples,
                         args.max_updates,
                         args.max_seconds,
                         args.max_checkpoints,
                         args.max_num_epochs,
                         args.max_num_checkpoint_not_improved)),
                    'Please specify at least one stopping criteria: --max-samples --max-updates --max-seconds '
                    '--max-checkpoints --max-num-epochs --max-num-checkpoint-not-improved')

    # Check and possibly adapt the parameters for source factors.
    # The number of factors is the number of entries in args.validation_source_factors.
    n_source_factors = len(args.validation_source_factors)
    args.source_factors_combine = _broadcast_factor_option(
        args.source_factors_combine, n_source_factors, 'combination strategies', 'source')
    args.source_factors_share_embedding = _broadcast_factor_option(
        args.source_factors_share_embedding, n_source_factors, 'vocabulary sharing flags', 'source')

    # Check and possibly adapt the parameters for target factors
    n_target_factors = len(args.validation_target_factors)
    args.target_factors_combine = _broadcast_factor_option(
        args.target_factors_combine, n_target_factors, 'combination strategies', 'target')
    args.target_factors_share_embedding = _broadcast_factor_option(
        args.target_factors_share_embedding, n_target_factors, 'vocabulary sharing flags', 'target')
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



