in src/pixparse/task/task_cruller_finetune_CORD.py [0:0]
def train_step(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Run a single finetuning train step on one batch.

    Moves the batch to the task's device, computes the (possibly
    accumulation-scaled) loss, backpropagates — through the AMP scaler when
    one is configured — and performs the optimizer/scheduler update on
    accumulation boundaries.

    Args:
        sample: Batch dict with keys ``"image"`` (pixel tensor),
            ``"label"`` (decoder input ids) and ``"text_target"``
            (target token ids for the loss).
            # assumes all three values are torch tensors — TODO confirm against the loader

    Returns:
        A metrics dict; within this excerpt it is only ever the empty dict
        ``result`` (returned early on non-update steps).
        # NOTE(review): no trailing `return result` is visible in this excerpt for the
        # update path — presumably it follows immediately after; verify in the full file.
    """
    image_input = sample["image"]
    label = sample["label"]
    text_target = sample["text_target"]
    result = {}
    # Async host->device copies (effective when the loader uses pinned memory).
    image_input = image_input.to(self.device_env.device, non_blocking=True)
    label = label.to(self.device_env.device, non_blocking=True)
    text_target = text_target.to(self.device_env.device, non_blocking=True)

    accum_steps = self.cfg.opt.grad_accum_steps
    # Only step the optimizer every `accum_steps` batches within the interval.
    need_update = (self.interval_batch_idx + 1) % accum_steps == 0

    def _forward():
        # Forward pass + loss under autocast (mixed precision when enabled).
        with self.autocast():
            if self.finetune_donut_weights:
                # HF Donut-style model: takes keyword args and computes with labels.
                #print(image_input.shape, label.shape, text_target.shape)
                output = self.model(pixel_values=image_input, decoder_input_ids=label, labels=text_target)
                logits = output["logits"]
            else:
                # Native model path: positional (image, decoder input ids) call.
                output = self.model(image_input, label)
                logits = output["logits"]
            #print(logits.shape, text_target.shape)
            # Flatten to (batch*seq, vocab) vs (batch*seq,) for token-level loss.
            loss = self.loss(
                logits.view(-1, self.vocab_size),
                text_target.view(-1),
            )
        #print(loss.item())
        if accum_steps > 1:
            # Scale so accumulated gradients average over the accumulation window.
            loss /= accum_steps
        return loss

    def _backward(_loss):
        if self.scaler is not None:
            # Scaler callable handles scaled backward, optional grad clipping,
            # and the optimizer step (only applied when need_update is True).
            self.scaler(
                _loss,
                self.optimizer,
                clip_grad=self.cfg.opt.clip_grad_value,
                clip_mode=self.cfg.opt.clip_grad_mode,
                parameters=self.model.parameters(),
                need_update=need_update,
            )
        else:
            # Plain FP32 path: manual backward, clip, and step.
            _loss.backward()
            if need_update:
                if self.cfg.opt.clip_grad_value is not None:
                    timm.utils.dispatch_clip_grad(
                        self.model.parameters(),
                        value=self.cfg.opt.clip_grad_value,
                        mode=self.cfg.opt.clip_grad_mode,
                    )
                self.optimizer.step()

    if self.has_no_sync and not need_update:
        # Skip gradient synchronization on pure accumulation steps.
        # presumably `has_no_sync` flags DDP-wrapped models exposing no_sync() — verify
        with self.model.no_sync():
            loss = _forward()
            _backward(loss)
    else:
        loss = _forward()
        _backward(loss)

    self.batch_idx += 1
    self.interval_batch_idx += 1
    if self.step % 100 == 0:
        # Periodic progress logging; note `loss` here is the accumulation-scaled
        # value of the most recent micro-batch, not an averaged window loss.
        self.monitor.log_step(
            "finetune",
            step_idx=self.step,
            step_end_idx=self.num_intervals * self.num_steps_per_interval,
            interval=self.interval_idx,
            loss=loss.item(),
            lr=self.get_current_lr(),
            metrics=None,
            eval_data=None,
        )

    if not need_update:
        # Mid-accumulation batch: no optimizer step, return without advancing `step`.
        return result

    # Optimizer stepped (inside _backward / scaler above): advance the global
    # step counter, update the LR schedule, and clear gradients for the next window.
    self.step += 1
    self.scheduler.step_update(self.step)
    self.optimizer.zero_grad()