src/pixparse/task/task_cruller_finetune_xent.py
def train_step(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    # Module-level imports assumed: `from typing import Any, Dict` and `import timm`.
    image_input = sample['image']
    label = sample['label']
    result = {}

    # Move inputs to the task's device without blocking the host.
    image_input = image_input.to(self.device_env.device, non_blocking=True)
    label = label.to(self.device_env.device, non_blocking=True)

    # An optimizer update only happens every `grad_accum_steps` batches.
    accum_steps = self.cfg.opt.grad_accum_steps
    need_update = (self.interval_batch_idx + 1) % accum_steps == 0

    def _forward():
        with self.autocast():
            outputs = self.model(image_input)
            loss = self.loss(
                outputs,
                label,
            )
        if accum_steps > 1:
            # Scale the loss so accumulated gradients average over the accumulation window.
            loss /= accum_steps
        return loss

    def _backward(_loss):
        if self.scaler is not None:
            # AMP path: the scaler handles backward, optional clipping, and the optimizer step.
            self.scaler(
                _loss,
                self.optimizer,
                clip_grad=self.cfg.opt.clip_grad_value,
                clip_mode=self.cfg.opt.clip_grad_mode,
                parameters=self.model.parameters(),
                need_update=need_update,
            )
        else:
            _loss.backward()
            if need_update:
                if self.cfg.opt.clip_grad_value is not None:
                    timm.utils.dispatch_clip_grad(
                        self.model.parameters(),
                        value=self.cfg.opt.clip_grad_value,
                        mode=self.cfg.opt.clip_grad_mode,
                    )
                self.optimizer.step()

    if self.has_no_sync and not need_update:
        # Accumulation-only step under DDP: skip the gradient all-reduce.
        with self.model.no_sync():
            loss = _forward()
            _backward(loss)
    else:
        loss = _forward()
        _backward(loss)

    self.batch_idx += 1
    self.interval_batch_idx += 1

    if self.step % self.eval_frequency == 0:
        self.monitor.log_step(
            'finetune',
            step_idx=self.step,
            step_end_idx=self.num_intervals * self.num_steps_per_interval,
            interval=self.interval_idx,
            loss=loss.item(),
            lr=self.get_current_lr(),
            metrics=None,
            eval_data=None,
        )

    if not need_update:
        return result

    # A full optimizer update was performed; advance the step count, LR schedule, and reset grads.
    self.step += 1
    self.scheduler.step_update(self.step)
    self.optimizer.zero_grad()
    return result
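
The gradient-accumulation bookkeeping above (accum_steps, need_update, and the no_sync() guard for DDP) can be easier to follow in isolation. The sketch below shows the same pattern in a minimal, self-contained form; run_epoch, loss_fn, and the loader are hypothetical names used only for illustration and are not part of pixparse.

# Minimal sketch of the accumulation / no_sync pattern used in train_step above.
# All names here (run_epoch, loss_fn, loader) are illustrative, not pixparse APIs.
import contextlib

import torch

def run_epoch(model, loader, optimizer, loss_fn, accum_steps=4, device='cuda'):
    optimizer.zero_grad()
    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        need_update = (batch_idx + 1) % accum_steps == 0
        # On accumulation-only steps, a DDP-wrapped model can skip gradient all-reduce.
        sync_ctx = model.no_sync() if hasattr(model, 'no_sync') and not need_update \
            else contextlib.nullcontext()
        with sync_ctx:
            # Average the loss over the accumulation window so gradient magnitudes
            # match a single large-batch step.
            loss = loss_fn(model(images), labels) / accum_steps
            loss.backward()
        if need_update:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()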