# archived/Image_Classification_VIT/scripts/train_images.py
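# NOTE: import block shown for self-containedness; the smart_sifting paths
# follow the SageMaker smart sifting library. `parse_args`,
# `ImageListBatchTransform`, and `ImageLoss` are assumed to be defined
# elsewhere in this repo.
import json
import math
import os
import time

import evaluate
import torch
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    get_scheduler,
)

from smart_sifting.dataloader.sift_dataloader import SiftingDataloader
from smart_sifting.sift_config.sift_configs import (
    LossConfig,
    RelativeProbabilisticSiftConfig,
    SiftingBaseConfig,
)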
def main():
args = parse_args()
##################### SIFTING FLAG #############################
    sifting = args.use_sifting > 0
################################################################
# If passed along, set the training seed now.
if args.seed is not None:
torch.manual_seed(args.seed)
dataset = load_from_disk(args.train_dir)
# If we don't have a validation split, split off a percentage of train as validation.
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
split = dataset["train"].train_test_split(args.train_val_split)
dataset["train"] = split["train"]
dataset["validation"] = split["test"]
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["label"].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}
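    # e.g. for labels == ["cat", "dog"]:
    #   label2id == {"cat": "0", "dog": "1"}, id2label == {"0": "cat", "1": "dog"}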
# Load pretrained model and image processor
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
args.model_name_or_path,
num_labels=len(labels),
        id2label=id2label,
label2id=label2id,
finetuning_task="image-classification",
)
image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path)
model = AutoModelForImageClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose(
[
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
val_transforms = Compose(
[
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]
)
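    # Training uses stochastic augmentation (random resized crop + horizontal
    # flip); evaluation uses a deterministic resize + center crop so that
    # metrics are reproducible across runs.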
def preprocess_train(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
return example_batch
def preprocess_val(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
return example_batch
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
if args.max_eval_samples is not None:
dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
# Set the validation transforms
eval_dataset = dataset["validation"].with_transform(preprocess_val)
# DataLoaders creation:
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["label"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_training_steps = args.num_train_epochs * len(train_dataloader)
    if sifting:
        print("***** Running training with smart sifting *****")
        sift_config = RelativeProbabilisticSiftConfig(
            beta_value=3,
            loss_history_length=500,
            loss_based_sift_config=LossConfig(
                sift_config=SiftingBaseConfig(sift_delay=10)
            ),
        )
        # Wrap the original dataloader; smart sifting filters each batch down
        # to the more informative (higher-loss) samples before they reach the
        # training loop.
        train_dataloader = SiftingDataloader(
            sift_config=sift_config,
            orig_dataloader=train_dataloader,
            batch_transforms=ImageListBatchTransform(),
            loss_impl=ImageLoss(),
            model=model,
        )
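        # For reference, ImageLoss and ImageListBatchTransform are expected to
        # implement smart_sifting's abstract interfaces (a sketch under that
        # assumption; the actual helpers live elsewhere in this repo):
        #
        #   class ImageLoss(Loss):                       # smart_sifting Loss
        #       def loss(self, model, transformed_batch, original_batch):
        #           ...  # return per-sample losses used to rank/keep samples
        #
        #   class ImageListBatchTransform(SiftingBatchTransform):
        #       def transform(self, batch): ...          # model batch -> siftable samples
        #       def reverse_transform(self, batch): ...  # siftable samples -> model batch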
eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        # The loop below steps the scheduler once per batch (no gradient
        # accumulation is applied inside it), so the schedule spans all batches.
        num_training_steps=num_training_steps,
    )
    # The size of the training dataloader may change after wrapping it for
    # sifting, which is when the step counts below would need recalculating:
    # num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards the number of training epochs would be recalculated:
    # args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # Figure out how often we should save training checkpoints.
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and checkpointing_steps.isdigit():
checkpointing_steps = int(checkpointing_steps)
    # Get the metric functions. `average="weighted"` is a compute-time argument
    # for f1/precision/recall, so it is not passed to `evaluate.load`.
    metric = evaluate.load("accuracy")
    clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
# Train!
total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num Epochs = {args.num_train_epochs}")
print(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f" Total optimization steps = {args.max_train_steps}")
    # Progress bar over the total number of training batches.
    progress_bar = tqdm(range(num_training_steps))
completed_steps = 0
starting_epoch = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
train_step_count = 0
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
total_loss = 0
for batch in train_dataloader:
train_start = time.perf_counter()
            batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
total_loss += loss.detach().float()
train_bp_start = time.perf_counter()
            print(f'train forward pass latency: {train_bp_start - train_start}')
loss.backward()
print(f'train backprop latency: {time.perf_counter() - train_bp_start}')
train_optim_start = time.perf_counter()
            optimizer.step()  # apply the gradients from this batch
lr_scheduler.step()
optimizer.zero_grad()
print(f'train optimizer step latency: {time.perf_counter() - train_optim_start}')
print(f'train total step latency: {time.perf_counter() - train_start}')
train_step_count += 1
print(f'train step count: {train_step_count}')
progress_bar.update(1)
completed_steps += 1
if completed_steps >= args.max_train_steps:
break
        # `loss` here is the last batch's loss of the epoch; `total_loss`
        # accumulates the per-batch losses over the whole epoch.
        print(
            "Epoch {}, Loss {:0.4f}".format(epoch, loss.detach().to("cpu"))
        )
model.eval()
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
                batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
predictions = outputs.logits.argmax(dim=-1)
references = batch["labels"]
metric.add_batch(
predictions=predictions,
references=references,
)
eval_metric = metric.compute()
print(f"epoch {epoch}: {eval_metric}")
print(f"epoch {epoch}: eval loss {loss}")
if args.output_dir is not None:
image_processor.save_pretrained(args.output_dir)
all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump(all_results, f)
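

# Standard script entry point (assumption: this file is executed directly,
# e.g. `python train_images.py --train_dir <path> ...`).
if __name__ == "__main__":
    main()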