in notebooks/src/code/train.py [0:0]
def get_model(model_args, data_args):
    """Load the pre-trained Config, Model and Tokenizer for the requested task.

    Parameters
    ----------
    model_args :
        Arguments naming the pre-trained model/config/tokenizer (``model_name_or_path``,
        optional ``config_name``/``tokenizer_name``, ``cache_dir``, ``model_revision``,
        ``use_auth_token``).
    data_args :
        Arguments describing the task (``task_name`` of 'ner' or 'mlm', ``num_labels``,
        ``max_seq_length``).

    Returns
    -------
    tuple
        (config, model, tokenizer)

    Raises
    ------
    ValueError
        If ``data_args.task_name`` is not 'ner' or 'mlm'.
    """
    # transformers expects True or None here (not False), so hoist the conversion once.
    use_auth_token = True if model_args.use_auth_token else None

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=data_args.num_labels,
        label2id={str(i): i for i in range(data_args.num_labels)},
        id2label={i: str(i) for i in range(data_args.num_labels)},
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=use_auth_token,
        # Potentially unnecessary extra kwargs for LayoutLM:
        max_position_embeddings=data_args.max_seq_length,  # TODO: VALIDATE THIS
        max_2d_position_embeddings=2 * data_args.max_seq_length,
    )

    tokenizer_name_or_path = (
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    )
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": True,
        "revision": model_args.model_revision,
        "use_auth_token": use_auth_token,
    }
    # Byte-level BPE tokenizers need a leading space to tokenize pre-split words correctly.
    if config.model_type in {"gpt2", "roberta"}:
        tokenizer_kwargs["add_prefix_space"] = True
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)

    # Map each supported task to its Auto model class; unknown tasks raise below.
    task_model_classes = {
        "ner": AutoModelForTokenClassification,
        "mlm": AutoModelForMaskedLM,
    }
    if data_args.task_name not in task_model_classes:
        raise ValueError(
            f"Unknown data_args.task_name '{data_args.task_name}' not in ('mlm', 'ner')"
        )
    model = task_model_classes[data_args.task_name].from_pretrained(
        model_args.model_name_or_path,
        # TensorFlow checkpoints are distributed as .ckpt files.
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=use_auth_token,
    )
    return config, model, tokenizer