in ppuda/utils/trainer.py [0:0]
# Presumed imports for this method (the exact location of the project's
# `accuracy` helper is an assumption):
import torch
import torch.nn as nn
from torch.nn.parallel import parallel_apply
from .utils import accuracy


def update(self, models, images, targets, ghn=None, graphs=None):

    logits = []
    loss = 0
    self.optimizer.zero_grad()

    with torch.cuda.amp.autocast(enabled=self.amp):

        if ghn is not None:
            # Predict parameters of all models with the GHN
            models = ghn(models,
                         graphs if isinstance(self.device, (list, tuple)) else graphs.to_device(self.device))
        if isinstance(self.device, (list, tuple)):
            # Multi-GPU training: self.device is assumed to be a list of integer
            # device indices, since it is also used to index into `models` below
            assert isinstance(models, (list, tuple)) and isinstance(models[0], (list, tuple)), \
                'models must be a list of lists'

            image_replicas = [images.to(device, non_blocking=True) for device in self.device]
            targets = targets.to(self.device[0], non_blocking=True)  # loss will be computed on the first device

            # assume the first device holds at least as many models as any other device
            models_per_device = len(models[0])
            for ind in range(models_per_device):  # loop over the model index within each device
                model_replicas = [models[device][ind] for device in self.device if ind < len(models[device])]
                outputs = parallel_apply(model_replicas,
                                         image_replicas[:len(model_replicas)],
                                         None,
                                         self.device[:len(model_replicas)])  # forward pass on each device in parallel
                # gather outputs from all devices and accumulate the loss on the first device
                for device, out in zip(self.device, outputs):
                    y = (out[0] if isinstance(out, (list, tuple)) else out).to(self.device[0])
                    loss += self.criterion(y, targets)
                    if self.auxiliary:
                        loss += self.auxiliary_weight * self.criterion(out[1].to(self.device[0]), targets)
                    logits.append(y.detach())
        else:
            images = images.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)
            if not isinstance(models, (list, tuple)):
                models = [models]
            for model in models:
                out = model(images)
                y = out[0] if isinstance(out, tuple) else out
                loss += self.criterion(y, targets)
                if self.auxiliary:
                    loss += self.auxiliary_weight * self.criterion(out[1], targets)
                logits.append(y.detach())

        loss = loss / len(logits)  # mean loss across models
    if torch.isnan(loss):
        raise RuntimeError('the loss is {}, unable to proceed'.format(loss))

    if self.amp:
        # Scales the loss, and calls backward()
        # to create scaled gradients
        self.scaler.scale(loss).backward()
        # Unscales the gradients of optimizer's assigned params in-place
        self.scaler.unscale_(self.optimizer)
    else:
        loss.backward()
    # Clip the gradient norm of all parameters handled by the optimizer
    parameters = []
    for group in self.optimizer.param_groups:
        parameters.extend(group['params'])
    nn.utils.clip_grad_norm_(parameters, self.grad_clip)

    if self.amp:
        # scaler.step() skips optimizer.step() if the gradients contain infs/NaNs
        # (they were already unscaled above)
        self.scaler.step(self.optimizer)
        # Updates the scale for next iteration
        self.scaler.update()
    else:
        self.optimizer.step()
    # Concatenate logits across models, duplicate targets accordingly
    logits = torch.stack(logits, dim=0)
    targets = targets.reshape(-1, 1).unsqueeze(0).expand(logits.shape[0], targets.shape[0], 1).reshape(-1)
    logits = logits.reshape(-1, logits.shape[-1])

    # Update training metrics
    prec1, prec5 = accuracy(logits, targets, topk=(1, 5))
    n = len(targets)
    self.metrics['loss'].update(loss.item(), n)
    self.metrics['top1'].update(prec1.item(), n)
    self.metrics['top5'].update(prec5.item(), n)

    self.step += 1

    return loss
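

# --- Usage sketch (not part of trainer.py) ----------------------------------
# A minimal sketch of how update() might be driven from a training loop.
# The Trainer construction, the `train_queue` loader and the `ghn`/`graphs`
# objects are hypothetical names used for illustration; only the update()
# call itself mirrors the method above.
#
#   trainer = Trainer(...)                      # hypothetical construction
#   for images, targets in train_queue:         # batches of (images, targets)
#       loss = trainer.update(models, images, targets, ghn=ghn, graphs=graphs)
#   print(trainer.metrics['top1'])              # assuming metrics track running averages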