in src/pixparse/task/task_cruller_finetune_xent.py [0:0]
def train_setup(self, num_batches_per_interval: int):
# Load model
# First load base model, then specialize it to fine-tuning end
# FIXME pass along resume arg here
if self.resume:
_logger.info(f"Resuming from existing checkpoint. ")
self.state_dict = {k.replace("module.", ""): v for k, v in self.state_dict.items()}
self.model.load_state_dict(self.state_dict)
self.model = nn.Sequential(OrderedDict([
    ("encoder", self.model.image_encoder),
    ("token_pool", GetCLSToken()),
    ("final_fc", nn.Linear(768, 16)),  # 768 = encoder embed dim, 16 classes in RVL-CDIP
    # no explicit softmax: nn.CrossEntropyLoss applies log-softmax internally
]))
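# GetCLSToken is assumed here to be a simple pooling module along these lines
# (a sketch, assuming the encoder emits a (B, N, C) token sequence with the CLS
# token at index 0; see the actual helper in pixparse for the real behavior):
#
#   class GetCLSToken(nn.Module):
#       def forward(self, x: torch.Tensor) -> torch.Tensor:
#           return x[:, 0, :]  # select the CLS token embedding for each sample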
# weights are loaded; now move the model to its device and wrap for distributed training
device = self.device_env.device
_logger.info(f"Local rank for this process: {self.device_env.local_rank}")
self.model.to(device)
if self.device_env.world_size > 1:
# NOTE: the plan is to add option for FSDP w/ HYBRID_SHARD strategy to extend
# model size capacity beyond DDP w/o overloading HF cluster NCCL throughput.
self.model = torch.nn.parallel.DistributedDataParallel(
self.model,
device_ids=[device],
static_graph=True,
)
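# A sketch of the FSDP HYBRID_SHARD alternative mentioned in the NOTE above
# (not wired up here; the wrapping and kwargs below are assumptions, not a
# settled design for this project):
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
#   self.model = FSDP(
#       self.model,
#       sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # shard within a node, replicate across nodes
#       device_id=device,
#   )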
self.has_no_sync = hasattr(self.model, 'no_sync')
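# no_sync() lets the train step skip DDP gradient all-reduce on intermediate
# grad-accumulation micro-batches; roughly (a sketch, `is_last_accum_step` is a
# placeholder for whatever flag the train step computes):
#
#   ctx = self.model.no_sync() if self.has_no_sync and not is_last_accum_step else nullcontext()
#   with ctx:
#       loss.backward()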
opt_kwargs = {}
if self.cfg.opt.betas is not None:
opt_kwargs['betas'] = self.cfg.opt.betas
if self.cfg.opt.momentum is not None:
opt_kwargs['momentum'] = self.cfg.opt.momentum
# standard opt
self.optimizer = create_optimizer_v2(
self.model,
self.cfg.opt.optimizer,
lr=self.cfg.opt.learning_rate,
eps=self.cfg.opt.eps,
layer_decay=self.cfg.opt.layer_decay,
**opt_kwargs,
)
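# layer_decay applies per-layer LR scaling, with encoder blocks nearer the head
# getting scales closer to 1.0; the resulting param groups can be inspected with e.g.:
#
#   for g in self.optimizer.param_groups:
#       _logger.debug(f"lr_scale={g.get('lr_scale', 1.0)} n_params={len(g['params'])}")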
# alternative: update only the classifier head parameters
# self.optimizer = torch.optim.AdamW(
#     [p for n, p in self.model.named_parameters() if "final_fc" in n],
#     lr=self.cfg.opt.learning_rate,
# )
if self.cfg.amp:
self.scaler = timm.utils.NativeScaler()
self.autocast = partial(torch.autocast, device_type=device.type, dtype=self.amp_dtype)
else:
self.scaler = None
self.autocast = nullcontext
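# In the train step these are consumed roughly as follows (a sketch;
# `image_input`, `label`, and `self.loss` are placeholder names, and the real
# step also handles grad accumulation / clipping):
#
#   with self.autocast():
#       output = self.model(image_input)
#       loss = self.loss(output, label)
#   if self.scaler is not None:
#       self.scaler(loss, self.optimizer)  # timm NativeScaler: scales, backwards, steps
#   else:
#       loss.backward()
#       self.optimizer.step()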
# FIXME will need two paths here to support interval vs step based durations
# in either case LR is always stepped with each optimizer update (train step)
self.num_steps_per_interval = num_batches_per_interval // self.cfg.opt.grad_accum_steps
self.scheduler, num_scheduled_epochs = create_scheduler_v2(
self.optimizer,
self.cfg.opt.scheduler,
warmup_lr=self.cfg.opt.warmup_learning_rate,
warmup_epochs=self.num_warmup_intervals,
num_epochs=self.num_intervals,
step_on_epochs=False, # sched is stepped on updates
updates_per_epoch=self.num_steps_per_interval,
)
self.scheduler.step_update(0)
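# Thereafter each optimizer update in the train loop is expected to advance the
# schedule with something like (sketch, `num_updates` being a running counter):
#
#   self.scheduler.step_update(num_updates)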