def create_model_config()

in sockeye/train.py [0:0]


def _create_factor_configs(side: str,
                           factor_vocab_sizes: List[int],
                           factors_num_embed: List[int],
                           factors_combine: List[str],
                           factors_share_embedding: List[bool],
                           num_embed: int):
    """
    Build the list of FactorConfig objects for one side (source or target) of the model.

    :param side: Name of the side used in log messages ("source" or "target").
    :param factor_vocab_sizes: Vocabulary sizes of the additional factors (primary vocabulary excluded).
    :param factors_num_embed: Embedding sizes given on the command line for the factors. May be empty.
           This list is never modified.
    :param factors_combine: Combination method for each factor.
    :param factors_share_embedding: For each factor, the share-embedding setting passed on to FactorConfig.
    :param num_embed: Embedding size of the primary (non-factor) embedding on this side.
    :return: A list with one FactorConfig per factor, or None if there are no factors.
    """
    if not factor_vocab_sizes:
        return None

    if not factors_num_embed:
        # This happens if the combination method is sum or average. We then
        # set the dimension to num_embed for all factors
        logger.info("Setting all %s factor embedding sizes to `num_embed` ('%d')", side, num_embed)
        embed_sizes = [num_embed] * len(factor_vocab_sizes)
    else:
        # Work on a copy so that the list stored on the parsed arguments is not altered.
        embed_sizes = list(factors_num_embed)
        # Check each individual factor: summed/averaged factors must match the primary embedding size.
        for i, combine in enumerate(factors_combine):
            if combine in (C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE):
                logger.info("Setting embedding size of %s factor %d to `num_embed` ('%d') for %s",
                            side, i + 1, num_embed,
                            "summing" if combine == C.FACTORS_COMBINE_SUM else "averaging")
                embed_sizes[i] = num_embed

    # zip() stops at the shortest input, so at most one config per factor vocabulary is produced.
    return [encoder.FactorConfig(size, dim, combine, share)
            for size, dim, combine, share in zip(factor_vocab_sizes,
                                                 embed_sizes,
                                                 factors_combine,
                                                 factors_share_embedding)]


def create_model_config(args: argparse.Namespace,
                        source_vocab_sizes: List[int],
                        target_vocab_sizes: List[int],
                        max_seq_len_source: int,
                        max_seq_len_target: int,
                        config_data: data_io.DataConfig) -> model.ModelConfig:
    """
    Create a ModelConfig from the argument given in the command line.

    The first entry of each vocabulary size list is the size of the primary vocabulary; any
    remaining entries are the sizes of the additional factor vocabularies. The parsed arguments
    are only read, never modified.

    :param args: Arguments as returned by argparse.
    :param source_vocab_sizes: The size of the source vocabulary (and source factors).
    :param target_vocab_sizes: The size of the target vocabulary (and target factors).
    :param max_seq_len_source: Maximum source sequence length.
    :param max_seq_len_target: Maximum target sequence length.
    :param config_data: Data config.
    :return: The model configuration.
    """
    num_embed_source, num_embed_target = get_num_embed(args)

    embed_dropout_source, embed_dropout_target = args.embed_dropout
    source_vocab_size, *source_factor_vocab_sizes = source_vocab_sizes
    target_vocab_size, *target_factor_vocab_sizes = target_vocab_sizes

    config_encoder, encoder_num_hidden = create_encoder_config(args, max_seq_len_source, max_seq_len_target,
                                                               num_embed_source)
    config_decoder = create_decoder_config(args, encoder_num_hidden, max_seq_len_source, max_seq_len_target,
                                           num_embed_target)

    source_factor_configs = _create_factor_configs(side="source",
                                                   factor_vocab_sizes=source_factor_vocab_sizes,
                                                   factors_num_embed=args.source_factors_num_embed,
                                                   factors_combine=args.source_factors_combine,
                                                   factors_share_embedding=args.source_factors_share_embedding,
                                                   num_embed=num_embed_source)

    target_factor_configs = _create_factor_configs(side="target",
                                                   factor_vocab_sizes=target_factor_vocab_sizes,
                                                   factors_num_embed=args.target_factors_num_embed,
                                                   factors_combine=args.target_factors_combine,
                                                   factors_share_embedding=args.target_factors_share_embedding,
                                                   num_embed=num_embed_target)

    allow_sparse_grad = args.update_interval == 1  # sparse embedding gradients do not work with grad_req='add'

    config_embed_source = encoder.EmbeddingConfig(vocab_size=source_vocab_size,
                                                  num_embed=num_embed_source,
                                                  dropout=embed_dropout_source,
                                                  factor_configs=source_factor_configs,
                                                  allow_sparse_grad=allow_sparse_grad)

    config_embed_target = encoder.EmbeddingConfig(vocab_size=target_vocab_size,
                                                  num_embed=num_embed_target,
                                                  dropout=embed_dropout_target,
                                                  factor_configs=target_factor_configs,
                                                  allow_sparse_grad=allow_sparse_grad)

    # The length task config is only created when a length task was requested.
    config_length_task = None
    if args.length_task is not None:
        config_length_task = layers.LengthRatioConfig(num_layers=args.length_task_layers,
                                                      weight=args.length_task_weight)

    return model.ModelConfig(config_data=config_data,
                             vocab_source_size=source_vocab_size,
                             vocab_target_size=target_vocab_size,
                             config_embed_source=config_embed_source,
                             config_embed_target=config_embed_target,
                             config_encoder=config_encoder,
                             config_decoder=config_decoder,
                             config_length_task=config_length_task,
                             weight_tying_type=args.weight_tying_type,
                             lhuc=args.lhuc is not None,
                             dtype=args.dtype)