# src/pixparse/task/task_cruller_pretrain.py
def train_step(self, sample):
    image_input, text_input, text_target = sample
    result = {}
    image_input = image_input.to(self.device_env.device, non_blocking=True)
    # Teacher forcing: the decoder consumes tokens [0, T-1) and is trained
    # to predict the sequence shifted left by one, tokens [1, T).
    text_input = text_input[:, :-1].to(self.device_env.device, non_blocking=True)
    text_target = text_target[:, 1:].to(self.device_env.device, non_blocking=True)

    accum_steps = self.cfg.opt.grad_accum_steps
    # The optimizer only steps on the last micro-batch of each accumulation group.
    need_update = (self.interval_batch_idx + 1) % accum_steps == 0

    def _forward():
        with self.autocast():
            output = self.model(image_input, text_input)
            logits = output['logits']
            loss = self.loss(
                logits.view(-1, self.vocab_size),
                text_target.view(-1),
            )
        if accum_steps > 1:
            # Pre-scale so the accumulated gradients average over the
            # micro-batches instead of summing.
            loss /= accum_steps
        return loss

    def _backward(_loss):
        if self.scaler is not None:
            # AMP path: a timm NativeScaler-style scaler scales the loss,
            # backprops, then unscales, clips, and steps the optimizer
            # when `need_update` is set.
            self.scaler(
                _loss,
                self.optimizer,
                clip_grad=self.cfg.opt.clip_grad_value,
                clip_mode=self.cfg.opt.clip_grad_mode,
                parameters=self.model.parameters(),
                need_update=need_update,
            )
        else:
            _loss.backward()
            if need_update:
                if self.cfg.opt.clip_grad_value is not None:
                    # Requires `import timm` at module scope.
                    timm.utils.dispatch_clip_grad(
                        self.model.parameters(),
                        value=self.cfg.opt.clip_grad_value,
                        mode=self.cfg.opt.clip_grad_mode,
                    )
                self.optimizer.step()

    if self.has_no_sync and not need_update:
        # Mid-accumulation under DDP: skip the gradient all-reduce; the
        # sync happens on the micro-batch that triggers the optimizer step.
        with self.model.no_sync():
            loss = _forward()
            _backward(loss)
    else:
        loss = _forward()
        _backward(loss)

    self.batch_idx += 1
    self.interval_batch_idx += 1
    if not need_update:
        # Still accumulating: no step count, scheduler update, or logging.
        return result

    self.step += 1
    self.scheduler.step_update(self.step)
    self.optimizer.zero_grad()

    if self.step % self.eval_frequency == 0:
        metrics, eval_gallery = self.get_train_ocr_metrics(sample)
        self.train_metrics |= metrics

        self.monitor.log_step(
            'train',
            step_idx=self.step,
            step_end_idx=self.num_intervals * self.num_steps_per_interval,
            interval=self.interval_idx,
            loss=loss.item(),
            lr=self.get_current_lr(),
            metrics=self.train_metrics,
            eval_data=eval_gallery,
        )

    return result
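
# --- Illustration (not from the source file; names below are hypothetical) ---
# A minimal, self-contained sketch of the gradient-accumulation pattern that
# train_step relies on: each micro-batch loss is pre-scaled by 1/accum_steps
# so the accumulated gradients equal the average over the group, and the
# optimizer steps once every `accum_steps` batches.
import torch

def demo_grad_accumulation(accum_steps: int = 4, num_batches: int = 8):
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for batch_idx in range(num_batches):
        x = torch.randn(2, 10)
        loss = model(x).pow(2).mean() / accum_steps  # pre-scale the loss
        loss.backward()  # gradients accumulate in .grad across micro-batches
        if (batch_idx + 1) % accum_steps == 0:  # mirrors `need_update` above
            optimizer.step()
            optimizer.zero_grad()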