def main()

in archived/Image_Classification_VIT/scripts/train_images.py


def main():
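    """Fine-tune a pretrained image-classification model (e.g. ViT) on a dataset
    saved with datasets.save_to_disk, optionally wrapping the train dataloader
    with loss-based sifting."""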
    args = parse_args()

    ##################### SIFTING FLAG #############################
    # Sifting is enabled when --use_sifting is passed a positive value.
    sifting = args.use_sifting > 0
    ################################################################



    # If passed along, set the training seed now.
    if args.seed is not None:
        torch.manual_seed(args.seed)


    dataset = load_from_disk(args.train_dir)
    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Prepare label mappings.
    # We'll include these in the model's config to get human readable labels in the Inference API.
    labels = dataset["train"].features["label"].names
    label2id = {label: str(i) for i, label in enumerate(labels)}
    id2label = {str(i): label for i, label in enumerate(labels)}
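    # e.g. labels = ["cat", "dog"] -> label2id = {"cat": "0", "dog": "1"},
    #                                 id2label = {"0": "cat", "1": "dog"}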

    # Load pretrained model and image processor
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        finetuning_task="image-classification",
    )
    image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path)
    model = AutoModelForImageClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
    )

    # Preprocessing the datasets

    # Define torchvision transforms to be applied to each image.
    if "shortest_edge" in image_processor.size:
        size = image_processor.size["shortest_edge"]
    else:
        size = (image_processor.size["height"], image_processor.size["width"])
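    # `size` is a single int for processors that resize by shortest edge, or a
    # (height, width) tuple for fixed-size processors such as ViT's.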
    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    train_transforms = Compose(
        [
            RandomResizedCrop(size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
    val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

    def preprocess_train(example_batch):
        """Apply _train_transforms across a batch."""
        example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
        return example_batch

    def preprocess_val(example_batch):
        """Apply _val_transforms across a batch."""
        example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
        return example_batch

    if args.max_train_samples is not None:
        dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)
    if args.max_eval_samples is not None:
        dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples))
    # Set the validation transforms
    eval_dataset = dataset["validation"].with_transform(preprocess_val)
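    # with_transform applies the preprocessing lazily, each time examples are
    # accessed, so decoded and augmented tensors are never materialized up front.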

    # DataLoaders creation:
    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_training_steps = args.num_train_epochs * len(train_dataloader)
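
    # Optional smart-sifting wrapper around the dataloader. The reading below is
    # based on the parameter names (an assumption, not verified against the sifting
    # library's docs): keep a rolling history of the last 500 per-sample losses and
    # probabilistically skip relatively low-loss samples, after a warm-up delay of
    # 10 batches.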

    if sifting:
        print("***** Will run the training using sifting *****")
        sift_config = RelativeProbabilisticSiftConfig(
            beta_value=3,
            loss_history_length=500,
            loss_based_sift_config=LossConfig(
                sift_config=SiftingBaseConfig(sift_delay=10)
            ),
        )
        train_dataloader = SiftingDataloader(
            sift_config=sift_config,
            orig_dataloader=train_dataloader,
            batch_transforms=ImageListBatchTransform(),
            loss_impl=ImageLoss(),
            model=model,
        )

    eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # The training loop below steps the optimizer (and scheduler) once per batch
    # and applies no gradient accumulation, so the scheduler is sized in raw
    # per-batch steps rather than accumulation-scaled steps.
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # Parse the checkpointing interval (note: checkpoint saving itself is not
    # wired into the training loop below).
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # Metric functions: plain accuracy drives the eval loop; the combined suite
    # below is kept for richer reporting but is not used in the loop.
    metric = evaluate.load("accuracy")
    clf_metrics = evaluate.combine([
        evaluate.load("accuracy"),
        evaluate.load("f1"),
        evaluate.load("precision"),
        evaluate.load("recall"),
    ])  # pass average="weighted" to clf_metrics.compute() for multi-class scores
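    # Sketch of how clf_metrics could be wired into the eval loop below
    # (hypothetical; the loop currently reports accuracy only):
    #     clf_metrics.add_batch(predictions=predictions, references=references)
    #     ...
    #     print(clf_metrics.compute(average="weighted"))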
    
    # Train!
    total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num Epochs = {args.num_train_epochs}")
    print(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    print(f"  Total optimization steps = {args.max_train_steps}")
    # Progress bar over the total number of training batches.
    progress_bar = tqdm(range(num_training_steps))
    completed_steps = 0
    starting_epoch = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    train_step_count = 0

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()

        total_loss = 0.0
        num_batches = 0
        for batch in train_dataloader:
            train_start = time.perf_counter()

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            # Keep track of the running loss over the epoch.
            total_loss += loss.detach().float()
            num_batches += 1
            train_bp_start = time.perf_counter()
            print(f'train forward pass latency: {train_bp_start - train_start}')
            loss.backward()
            print(f'train backprop latency: {time.perf_counter() - train_bp_start}')
            train_optim_start = time.perf_counter()
            optimizer.step() #gather gradient updates from all cores and apply them
            lr_scheduler.step()
            optimizer.zero_grad()
            print(f'train optimizer step latency: {time.perf_counter() - train_optim_start}')
            print(f'train total step latency: {time.perf_counter() - train_start}')
            train_step_count += 1
            print(f'train step count: {train_step_count}')

            progress_bar.update(1)
            completed_steps += 1


            if completed_steps >= args.max_train_steps:
                break
        print("Epoch {}, Mean train loss {:0.4f}".format(epoch, float(total_loss) / max(num_batches, 1)))
        model.eval()
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
            predictions = outputs.logits.argmax(dim=-1)
            references = batch["labels"]
            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metric = metric.compute()
        print(f"epoch {epoch}: {eval_metric}")
        print(f"epoch {epoch}: eval loss {loss}")


    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        # Save the fine-tuned model weights and the image processor for later reuse.
        model.save_pretrained(args.output_dir)
        image_processor.save_pretrained(args.output_dir)
        all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
            json.dump(all_results, f)
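
# Example invocation (hypothetical values; argument names are taken from the
# args.* accesses above, while parse_args() itself is defined elsewhere in the script):
#   python scripts/train_images.py \
#       --train_dir /path/to/dataset \
#       --model_name_or_path google/vit-base-patch16-224-in21k \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 32 \
#       --use_sifting 1 \
#       --output_dir ./vit-output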