def update()

in ppuda/utils/trainer.py [0:0]


    def update(self, models, images, targets, ghn=None, graphs=None):
        """Run one optimization step for one or more models on a single batch.

        Performs an optional GHN parameter-prediction pass, a forward pass
        (single- or multi-GPU), mean-loss backward with optional AMP scaling,
        gradient clipping, an optimizer step and training-metric updates.

        Args:
            models: a single model, a list of models, or — in multi-GPU mode —
                a list of lists of models (one inner list per device).
            images: input batch tensor.
            targets: target labels tensor.
            ghn: optional hypernetwork that predicts the models' parameters
                before the forward pass.
            graphs: computational graphs consumed by the GHN (moved to
                ``self.device`` unless training on multiple GPUs).

        Returns:
            The mean loss (scalar tensor) across all models in the batch.

        Raises:
            RuntimeError: if the loss is not finite (NaN or inf), since
                training cannot meaningfully continue.
        """
        logits = []
        loss = 0

        self.optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=self.amp):

            if ghn is not None:
                # Predict parameters of all models with the hypernetwork.
                models = ghn(models, graphs if isinstance(self.device, (list, tuple)) else graphs.to_device(self.device))

            if isinstance(self.device, (list, tuple)):
                # Multi-GPU training: replicate images to every device; keep
                # targets on the first device, where the loss is accumulated.
                assert isinstance(models, (list, tuple)) and isinstance(models[0], (list, tuple)), 'models must be a list of lists'
                image_replicas = [images.to(device, non_blocking=True) for device in self.device]
                targets = targets.to(self.device[0], non_blocking=True)  # loss will be computed on the first device

                models_per_device = len(models[0])      # assume that on the first device the number of models is >= than on other devices
                for ind in range(models_per_device):    # for index within each device
                    model_replicas = [models[device][ind] for device in self.device if ind < len(models[device])]
                    outputs = parallel_apply(model_replicas,
                                             image_replicas[:len(model_replicas)],
                                             None,
                                             self.device[:len(model_replicas)])  # forward pass at each device in parallel

                    # Gather outputs from the devices and accumulate the loss
                    # on the first device. (The per-device loop variable of the
                    # original zip was unused, so iterate outputs directly.)
                    for out in outputs:
                        y = (out[0] if isinstance(out, (list, tuple)) else out).to(self.device[0])

                        loss += self.criterion(y, targets)
                        if self.auxiliary:
                            loss += self.auxiliary_weight * self.criterion(out[1].to(self.device[0]), targets)
                        logits.append(y.detach())

            else:
                # Single-device training.
                images = images.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                if not isinstance(models, (list, tuple)):
                    models = [models]

                for model in models:
                    out = model(images)
                    y = out[0] if isinstance(out, tuple) else out

                    loss += self.criterion(y, targets)
                    if self.auxiliary:
                        # Auxiliary-head loss (out[1]) weighted as in NAS-style training.
                        loss += self.auxiliary_weight * self.criterion(out[1], targets)

                    logits.append(y.detach())

        loss = loss / len(logits)         # mean loss across models

        # Fix: abort on +/-inf as well as NaN — neither is recoverable, and an
        # infinite loss would previously slip past the torch.isnan check.
        if not torch.isfinite(loss):
            raise RuntimeError('the loss is {}, unable to proceed'.format(loss))

        if self.amp:
            # Scales the loss, and calls backward()
            # to create scaled gradients
            self.scaler.scale(loss).backward()

            # Unscales the gradients of optimizer's assigned params in-place,
            # so that the clipping below acts on true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
        else:
            loss.backward()

        # Clip the gradients of every parameter the optimizer manages.
        parameters = []
        for group in self.optimizer.param_groups:
            parameters.extend(group['params'])

        nn.utils.clip_grad_norm_(parameters, self.grad_clip)
        if self.amp:
            # Unscales gradients and calls
            # or skips optimizer.step() (skipped if grads contain inf/NaN)
            self.scaler.step(self.optimizer)

            # Updates the scale for next iteration
            self.scaler.update()
        else:
            self.optimizer.step()

        # Concatenate logits across models and duplicate targets accordingly so
        # that accuracy is computed over all (model, sample) pairs at once.
        logits = torch.stack(logits, dim=0)
        targets = targets.reshape(-1, 1).unsqueeze(0).expand(logits.shape[0], targets.shape[0], 1).reshape(-1)
        logits = logits.reshape(-1, logits.shape[-1])

        # Update training metrics
        prec1, prec5 = accuracy(logits, targets, topk=(1, 5))
        n = len(targets)
        self.metrics['loss'].update(loss.item(), n)
        self.metrics['top1'].update(prec1.item(), n)
        self.metrics['top5'].update(prec5.item(), n)

        self.step += 1

        return loss