def create_text_decoder()

in src/pixparse/models/text_decoder_hf.py


import transformers

# TextDecoderCfg is pixparse's text decoder config dataclass, defined elsewhere in the package
# (its import is not shown in this excerpt).


def create_text_decoder(cfg: TextDecoderCfg) -> transformers.PreTrainedModel:
    assert cfg.name, 'cfg.name must specify a Hugging Face model name or path'

    # Load the base config and enable cross-attention so the decoder can attend
    # to the image encoder outputs.
    config = transformers.AutoConfig.from_pretrained(cfg.name)
    config.add_cross_attention = True
    if False:  # FIXME these were set in Donut but missed in the first pass, should compare
        config.is_encoder_decoder = False
        config.scale_embedding = True
        config.add_final_layer_norm = True
    if cfg.num_decoder_layers is not None:
        config.decoder_layers = cfg.num_decoder_layers
    if cfg.max_length is not None:
        config.max_position_embeddings = cfg.max_length
    #config.vocab_size =   # FIXME set vocab size here or rely on model resize when tokens added?

    if cfg.pretrained:
        # Load pretrained weights, applying the modified config.
        model = transformers.AutoModelForCausalLM.from_pretrained(
            cfg.name,
            config=config,
        )
    else:
        # Build a randomly initialized model from the config alone.
        model = transformers.AutoModelForCausalLM.from_config(
            config,
        )
    # TODO Following is the donut hack. Unused without generate().
    # model.model.decoder.embed_tokens.padding_idx = cfg.pad_token_id

    return model
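
A minimal usage sketch for reference. The TextDecoderCfg below is an illustrative stand-in exposing only the fields this function reads (name, pretrained, num_decoder_layers, max_length); the real dataclass is defined elsewhere in pixparse, and facebook/bart-base is just an example checkpoint.

from dataclasses import dataclass
from typing import Optional


@dataclass
class TextDecoderCfg:  # illustrative stand-in, not the pixparse definition
    name: str = "facebook/bart-base"          # example Hugging Face checkpoint
    pretrained: bool = True                   # load checkpoint weights
    num_decoder_layers: Optional[int] = None  # None keeps the checkpoint's depth
    max_length: Optional[int] = None          # None keeps the checkpoint's position embeddings


decoder = create_text_decoder(TextDecoderCfg())
print(type(decoder).__name__)              # BartForCausalLM for a BART checkpoint
print(decoder.config.add_cross_attention)  # True: decoder layers gain cross-attention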