# torchbenchmark/e2e_models/hf_bert/trainer.py
import torch
import math
import os
from pathlib import Path
from torch.utils.data import DataLoader
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
from accelerate import Accelerator
from transformers import (
AdamW,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
default_data_collator,
get_scheduler,
)
from torchbenchmark.util.framework.transformers.text_classification.dataset import prep_dataset, preprocess_dataset, prep_labels
from torchbenchmark.util.framework.transformers.text_classification.args import parse_args
# Module-level setup: benchmark artifacts are written to a ".output" directory
# located next to this file.
CURRENT_DIR = Path(os.path.realpath(__file__)).parent
OUTPUT_DIR = str(CURRENT_DIR / ".output")
# cuDNN flags: benchmark mode on, deterministic mode off.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# Fixed seed for torch's random number generator.
torch.manual_seed(1337)
class Model(BenchmarkModel):
    """End-to-end fine-tuning benchmark for a BERT sequence classifier.

    The constructor builds a fixed argument list ("bert-base-cased", sequence
    length 128, learning rate 2e-5, 3 epochs), parses it with ``parse_args``
    and calls ``prep`` to create the model, optimizer, dataloaders, learning
    rate scheduler and ``Accelerator``. ``train`` then runs the training loop
    followed by an evaluation pass after every epoch.
    """

    task = NLP.LANGUAGE_MODELING

    def __init__(self, device=None, train_bs=32, task_name="cola"):
        """Parse the benchmark arguments and prepare all training state.

        Args:
            device: Device name requested by the caller. ``train`` refuses to
                run unless this equals ``"cuda"``.
            train_bs: Value passed as ``--per_device_train_batch_size``.
            task_name: Value passed as ``--task_name``.
        """
        super().__init__()
        self.device = device
        model_name = "bert-base-cased"
        max_seq_length = "128"
        learning_rate = "2e-5"
        num_train_epochs = "3"
        # this benchmark runs on a single GPU
        cuda_visible_devices = "0"
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
        in_arg = [
            "--model_name_or_path", model_name,
            "--task_name", task_name,
            "--do_train",
            "--do_eval",
            "--max_seq_length", max_seq_length,
            "--per_device_train_batch_size", str(train_bs),
            "--learning_rate", learning_rate,
            "--num_train_epochs", num_train_epochs,
            "--output_dir", OUTPUT_DIR,
        ]
        model_args, data_args, training_args = parse_args(in_arg)
        # setup other members
        self.prep(model_args, data_args, training_args)

    def prep(self, model_args, data_args, training_args):
        """Create the model, data pipeline, optimizer and scheduler.

        On return the instance carries ``training_args``, ``is_regression``,
        ``model``, ``optimizer``, ``train_dataloader``, ``eval_dataloader``,
        ``lr_scheduler`` and ``accelerator``.

        Args:
            model_args: Model-related arguments returned by ``parse_args``.
            data_args: Dataset-related arguments returned by ``parse_args``.
            training_args: Training arguments returned by ``parse_args``.
                ``max_steps`` and ``num_train_epochs`` are updated in place so
                that they are consistent with each other.
        """
        # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
        accelerator = Accelerator()
        accelerator.wait_for_everyone()
        raw_datasets = prep_dataset(data_args, training_args)
        num_labels, label_list, is_regression = prep_labels(data_args, raw_datasets)

        # Load pretrained model and tokenizer
        #
        # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
        # download model & vocab.
        config = AutoConfig.from_pretrained(
            model_args.config_name if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
            use_fast=model_args.use_fast_tokenizer,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
        )
        train_dataset, eval_dataset, _predict_dataset = preprocess_dataset(
            data_args, training_args, config, model,
            tokenizer, raw_datasets, num_labels, label_list, is_regression,
        )

        # Pick the collate function: the default collator when the data is
        # already padded, a padding collator (multiple of 8) for fp16, and
        # otherwise None so that DataLoader uses its own default.
        if data_args.pad_to_max_length:
            data_collator = default_data_collator
        elif training_args.fp16:
            data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
        else:
            data_collator = None
        train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            collate_fn=data_collator,
            batch_size=training_args.per_device_train_batch_size,
        )
        eval_dataloader = DataLoader(
            eval_dataset,
            collate_fn=data_collator,
            batch_size=training_args.per_device_eval_batch_size,
        )

        # Optimizer
        # Split weights in two groups, one with weight decay and the other not.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=training_args.learning_rate)

        # Prepare everything with our `accelerator`.
        model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader
        )

        # Note -> the training dataloader needs to be prepared before we grab its length below (because its length
        # will be shorter in multiprocess)
        # Scheduler and math around the number of training steps.
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
        if training_args.max_steps is None or training_args.max_steps == -1:
            training_args.max_steps = training_args.num_train_epochs * num_update_steps_per_epoch
        else:
            training_args.num_train_epochs = math.ceil(training_args.max_steps / num_update_steps_per_epoch)
        training_args.num_train_epochs = int(training_args.num_train_epochs)
        lr_scheduler = get_scheduler(
            name=training_args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.max_steps,
        )

        # Setup class members
        self.training_args = training_args
        self.is_regression = is_regression
        self.model = model
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator

    def get_module(self):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("get_module is not supported by this model")

    def train(self, niter=1):
        """Run the full training loop with an evaluation pass after each epoch.

        Args:
            niter: Accepted for interface compatibility; not used by this
                method.

        Raises:
            NotImplementedError: If JIT was requested or the device is not
                ``"cuda"``.
            AssertionError: If ``training_args.do_train`` is not set.
        """
        # `jit` is never assigned by this class; treat a missing attribute as
        # "JIT not requested" instead of failing with AttributeError.
        if getattr(self, "jit", False):
            raise NotImplementedError("JIT is not supported by this model")
        if not self.device == "cuda":
            raise NotImplementedError("Only CUDA is supported by this model")
        assert self.training_args.do_train, "Must train with `do_train` arg being set"

        accumulation_steps = self.training_args.gradient_accumulation_steps
        last_step = len(self.train_dataloader) - 1
        completed_steps = 0
        for _epoch in range(self.training_args.num_train_epochs):
            self.model.train()
            for step, batch in enumerate(self.train_dataloader):
                outputs = self.model(**batch)
                loss = outputs.loss / accumulation_steps
                self.accelerator.backward(loss)
                # Update once every `accumulation_steps` micro-batches (and on
                # the final batch), which matches the
                # ceil(len / accumulation_steps) updates per epoch that `prep`
                # assumes when sizing the learning-rate schedule.
                if (step + 1) % accumulation_steps == 0 or step == last_step:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()
                    completed_steps += 1
                if completed_steps >= self.training_args.max_steps:
                    break

            self.model.eval()
            # Evaluation only needs forward passes, so skip autograd tracking.
            with torch.no_grad():
                for batch in self.eval_dataloader:
                    outputs = self.model(**batch)
                    # The predictions are computed but not consumed further.
                    if self.is_regression:
                        _predictions = outputs.logits.squeeze()
                    else:
                        _predictions = outputs.logits.argmax(dim=-1)