in archived/nas_for_llm_with_amt/training.py [0:0]
import logging
import os
import shutil

import evaluate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    get_scheduler,
)

# ModelArguments, DataTrainingArguments, SmallSearchSpace, mask_bert and TASKINFO are defined in
# this repository's companion modules; their import statements are omitted here because the exact
# module names are not part of this excerpt.


def main():
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, TrainingArguments)
)
(
model_args,
data_args,
training_args,
) = parser.parse_args_into_dataclasses()
model_type = model_args.model_name_or_path
task_name = data_args.task_name
seed = training_args.seed
per_device_train_batch_size = training_args.per_device_train_batch_size
per_device_eval_batch_size = training_args.per_device_eval_batch_size
tokenizer = AutoTokenizer.from_pretrained(model_type)
padding = "max_length"
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
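    # Tokenize one or two text columns per example; padding to max_seq_length yields fixed-length
    # inputs so that default_data_collator can stack them without dynamic padding.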
def preprocess_function(examples):
args = (
(examples[sentence1_key],)
if sentence2_key is None
else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
return result
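    # Load the GLUE task and its evaluation metric from the Hugging Face hub.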
raw_datasets = load_dataset("glue", task_name)
    # NOTE: these column names only match sentence-pair GLUE tasks such as MRPC, STS-B, RTE and
    # WNLI; other GLUE tasks use different column names and would need a task-specific mapping.
    sentence1_key, sentence2_key = ("sentence1", "sentence2")
metric = evaluate.load("glue", task_name)
preproc_datasets = raw_datasets.map(
preprocess_function,
batched=True,
desc="Running tokenizer on dataset",
)
    if data_args.is_regression:
        # STS-B is a regression task with a single float label rather than a ClassLabel feature.
        num_labels = 1
    else:
        label_list = preproc_datasets["train"].features["label"].names
        num_labels = len(label_list)
train_dataset = preproc_datasets["train"]
train_dataset = train_dataset.remove_columns(["idx"])
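    # Hold out 30% of the original GLUE training split as a validation set.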
split = train_dataset.train_test_split(train_size=0.7, seed=seed)
train_dataset = split["train"]
valid_dataset = split["test"]
data_collator = default_data_collator
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
batch_size=per_device_train_batch_size,
collate_fn=data_collator,
)
eval_dataloader = DataLoader(
valid_dataset,
batch_size=per_device_eval_batch_size,
collate_fn=data_collator,
)
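    # Pre-trained backbone with a freshly initialised sequence-classification head for this task.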
config = AutoConfig.from_pretrained(
model_type,
num_labels=num_labels,
finetuning_task=task_name,
)
model = AutoModelForSequenceClassification.from_pretrained(
model_type,
config=config,
)
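    # AdamW with a linear learning-rate schedule and warmup over a fraction of the total steps.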
optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)
num_training_steps = int(training_args.num_train_epochs * len(train_dataloader))
warmup_steps = int(training_args.warmup_ratio * num_training_steps)
lr_scheduler = get_scheduler(
name="linear",
optimizer=optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=num_training_steps,
)
progress_bar = tqdm(range(num_training_steps))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
step = 0
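    # Distillation target for the sub-networks: MSE against the super-network logits for regression,
    # otherwise KL divergence with both distributions in log space (hence log_target=True).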
if data_args.is_regression:
distillation_loss = nn.MSELoss()
else:
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
distillation_loss = lambda x, y: kl_loss(
F.log_softmax(x, dim=-1), F.log_softmax(y, dim=-1)
)
model_type = model.config._name_or_path
if model_type.startswith("bert"):
mask = mask_bert
else:
        raise AttributeError(f"Model {model_type} is currently not supported!")
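    # Sampler that draws head/FFN masks from the search space to define sub-networks of the super-network.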
sampler = SmallSearchSpace(
model.config, rng=np.random.RandomState(seed=training_args.seed)
)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
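    # Sandwich rule: every step trains the full super-network on the task loss and distills the
    # smallest plus two randomly sampled sub-networks against the super-network's logits.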
for epoch in range(int(training_args.num_train_epochs)):
model.train()
train_loss = 0
for batch in train_dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
            # update the largest sub-network (i.e., the super-network)
outputs = model(**batch)
loss = outputs.loss
y_teacher = outputs.logits.detach()
loss.backward()
# update smallest sub-network
head_mask, ffn_mask = sampler.get_smallest_sub_network()
head_mask = head_mask.to(device=device, dtype=model.dtype)
ffn_mask = ffn_mask.to(device=device, dtype=model.dtype)
handles = mask(model, ffn_mask, head_mask)
outputs = model(head_mask=head_mask, **batch)
for handle in handles:
handle.remove()
loss = distillation_loss(outputs.logits, y_teacher)
loss.backward()
            # update first random sub-network
head_mask, ffn_mask = sampler()
head_mask = head_mask.to(device=device, dtype=model.dtype)
ffn_mask = ffn_mask.to(device=device, dtype=model.dtype)
handles = mask(model, ffn_mask, head_mask)
outputs = model(head_mask=head_mask, **batch)
for handle in handles:
handle.remove()
loss = distillation_loss(outputs.logits, y_teacher)
loss.backward()
            # update second random sub-network
head_mask, ffn_mask = sampler()
head_mask = head_mask.to(device=device, dtype=model.dtype)
ffn_mask = ffn_mask.to(device=device, dtype=model.dtype)
handles = mask(model, ffn_mask, head_mask)
outputs = model(head_mask=head_mask, **batch)
for handle in handles:
handle.remove()
loss = distillation_loss(outputs.logits, y_teacher)
loss.backward()
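            # Gradients from all four forward/backward passes above are accumulated into a single
            # optimizer step.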
step += 1
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
            train_loss += loss.item()  # last distillation loss of this step; .item() avoids retaining the graph
        model.eval()
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
            logits = outputs.logits
            predictions = (
                torch.squeeze(logits)
                if data_args.is_regression
                else torch.argmax(logits, dim=-1)
            )
            metric.add_batch(predictions=predictions, references=batch["labels"])
eval_metric = metric.compute()
metric_name = TASKINFO[data_args.task_name]["metric"]
print(f"epoch: {epoch}")
print(f"training loss: {train_loss / len(train_dataloader)}")
print(f"number of parameters: {n_params}")
print(f"validation error: {1 - eval_metric[metric_name]}")
        if training_args.save_strategy == "epoch":
            os.makedirs(training_args.output_dir, exist_ok=True)
            logging.info(f"Store checkpoint in: {training_args.output_dir}")
            model.save_pretrained("checkpoint")
            shutil.make_archive(
                base_name=os.path.join(training_args.output_dir, "model"),
                format="gztar",
                root_dir="checkpoint",
            )