def train_setup()

in src/pixparse/task/task_cruller_finetune_xent.py

Prepares the fine-tuning task for training: optionally resumes encoder weights from a checkpoint, attaches a CLS-token pooling and linear classification head (16 RVL-CDIP classes), moves the model to the local device, wraps it in DDP for multi-GPU runs, and creates the optimizer, optional AMP scaler, and LR scheduler.

    def train_setup(self, num_batches_per_interval: int):
        # Load model

        # First load base model, then specialize it to fine-tuning end
        
        # FIXME pass along resume arg here
        if self.resume:
            _logger.info(f"Resuming from existing checkpoint. ")
            self.state_dict = {k.replace("module.", ""): v for k, v in self.state_dict.items()}
            self.model.load_state_dict(self.state_dict)
        
        self.model = nn.Sequential(OrderedDict([
            ("encoder", self.model.image_encoder),
            ("token_pool", GetCLSToken()),
            ("final_fc", nn.Linear(768, 16)),  # 16 classes in RVL-CDIP
            # no Softmax: cross-entropy loss is applied to raw logits
        ]))
        # We don't need weights loaded / the model on device until this point.
        _logger.info(f"Local rank for this process: {self.device_env.local_rank}")
        device = torch.device(f"cuda:{self.device_env.local_rank}")
        self.model.to(device)
        if self.device_env.world_size > 1:
            # NOTE: the plan is to add option for FSDP w/ HYBRID_SHARD strategy to extend
            # model size capacity beyond DDP w/o overloading HF cluster NCCL throughput.
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[device],
                static_graph=True,
            )
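            # Record whether no_sync() is available so gradient sync can be skipped on accumulation steps.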
            self.has_no_sync = hasattr(self.model, 'no_sync')

        opt_kwargs = {}
        if self.cfg.opt.betas is not None:
            opt_kwargs['betas'] = self.cfg.opt.betas
        if self.cfg.opt.momentum is not None:
            opt_kwargs['momentum'] = self.cfg.opt.momentum

        # Standard optimizer over all model parameters.
        self.optimizer = create_optimizer_v2(
            self.model,
            self.cfg.opt.optimizer,
            lr=self.cfg.opt.learning_rate,
            eps=self.cfg.opt.eps,
            layer_decay=self.cfg.opt.layer_decay,
            **opt_kwargs,
        )

        # Alternative: optimize only the classifier head
        # self.optimizer = torch.optim.AdamW(
        #     [p for n, p in self.model.named_parameters() if "final_fc" in n],
        #     lr=self.cfg.opt.learning_rate,
        # )

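        # Mixed precision: native loss scaler + autocast context factory; otherwise run in full precision.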
        if self.cfg.amp:
            self.scaler = timm.utils.NativeScaler()
            self.autocast = partial(torch.autocast, device_type=device.type, dtype=self.amp_dtype)
        else:
            self.scaler = None
            self.autocast = nullcontext

        # FIXME will need two paths here to support interval vs step based durations
        #  in either case LR is always stepped with each optimizer update (train step)
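        #  e.g. 10,000 batches per interval with grad_accum_steps=4 -> 2,500 optimizer updates (and LR steps) per interval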
        self.num_steps_per_interval = num_batches_per_interval // self.cfg.opt.grad_accum_steps
        self.scheduler, num_scheduled_epochs = create_scheduler_v2(
            self.optimizer,
            self.cfg.opt.scheduler,
            warmup_lr=self.cfg.opt.warmup_learning_rate,
            warmup_epochs=self.num_warmup_intervals,
            num_epochs=self.num_intervals,
            step_on_epochs=False,  # sched is stepped on updates
            updates_per_epoch=self.num_steps_per_interval,
        )
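        # Initialize the LR to the schedule's value at update 0 (start of warmup).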
        self.scheduler.step_update(0)
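
GetCLSToken is defined elsewhere in pixparse and is not shown here. As a minimal, self-contained sketch of what that pooling step amounts to (the class name, shapes, and values below are illustrative assumptions, not the pixparse implementation):

import torch
import torch.nn as nn

class ClsTokenPool(nn.Module):
    """Select the first (CLS) token from a [batch, tokens, dim] sequence."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, 0]

# A ViT-style encoder output: batch of 2, 197 tokens, embedding dim 768.
features = torch.randn(2, 197, 768)
pooled = ClsTokenPool()(features)     # shape: [2, 768]
logits = nn.Linear(768, 16)(pooled)   # raw logits for the 16 RVL-CDIP classes

PyTorch's nn.CrossEntropyLoss expects raw logits, which is consistent with the head above omitting a softmax layer.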