def build_model()

in fairseq/models/bert_seq2seq.py
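
Builds the encoder and decoder token embeddings (optionally initialized from a pretrained ELMo/BiLM checkpoint or from an embedding file), wraps them in a TransformerEncoder and a SelfTransformerDecoder, and returns a Transformer_nonautoregressive model.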


    def build_model(cls, args, task):
        """Build a new model instance."""

        #for ds in task.datasets.values():
        #    ds.target_is_source = True

        # make sure all arguments are present in older models
        base_architecture(args)
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, is_encoder, path=None):
            # Embeddings can be initialized from a pretrained language model
            # checkpoint ('elmo:<path>' or 'bilm:<path>'), from a plain-text
            # embedding file, or learned from scratch.
            if path is not None:
                if path.startswith('elmo:'):
                    lm_path = path[5:]
                    lm_task = LanguageModelingTask(args, dictionary, dictionary)
                    models, _ = utils.load_ensemble_for_inference([lm_path], lm_task, {'remove_head': True})
                    assert len(models) == 1, 'ensembles are currently not supported for elmo embeddings'

                    embedder = ElmoTokenEmbedder(models[0], dictionary.eos(), dictionary.pad(), add_bos=is_encoder,
                                                 remove_bos=is_encoder, combine_tower_states=is_encoder,
                                                 projection_dim=embed_dim, add_final_predictive=is_encoder,
                                                 add_final_context=is_encoder)
                    return embedder
                elif path.startswith('bilm:'):
                    lm_path = path[5:]
                    lm_task = LanguageModelingTask(args, dictionary, dictionary)
                    models, _ = utils.load_ensemble_for_inference(
                        [lm_path],
                        lm_task,
                        {'remove_head': True,
                         'dropout': args.bilm_model_dropout,
                         'attention_dropout': args.bilm_attention_dropout,
                         'relu_dropout': args.bilm_relu_dropout})
                    assert len(models) == 1, 'ensembles are currently not supported for bilm embeddings'

                    # encoder and decoder wrap the pretrained LM in different embedder classes
                    return BILMEmbedder(models[0], args, args.encoder_embed_dim) if is_encoder \
                        else LMEmbedder(models[0], args.decoder_embed_dim)

            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
            # if an embedding file was provided, initialize the weights from it
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, is_encoder=True, path=args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            # sharing all embeddings also implies tying the decoder input and output embeddings
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, is_encoder=True, path=args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, is_encoder=False, path=args.decoder_embed_path
            )

        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens, args.encoder_embed_scale)
        decoder = SelfTransformerDecoder(args, tgt_dict, decoder_embed_tokens, args.decoder_embed_scale)
        return Transformer_nonautoregressive(encoder, decoder)
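
For context, a minimal sketch of how this classmethod is typically reached through fairseq's standard training entry points; the wiring below (get_training_parser, parse_args_and_arch, setup_task, task.build_model) reflects fairseq's usual API and is an assumption for illustration, not part of this file:

    # Hedged sketch: typical fairseq wiring that ends up calling
    # <ModelClass>.build_model(args, task), i.e. the method above.
    from fairseq import options, tasks

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)   # --arch selects the registered model/architecture

    task = tasks.setup_task(args)                # supplies task.source_dictionary / task.target_dictionary
    model = task.build_model(args)               # dispatches to the registered model's build_model(args, task)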