def train()

in libs/solaris/nets/train.py [0:0]


    def train(self):
        """Run training on the model.

        For ``self.framework == 'keras'`` this delegates to
        ``self.model.fit_generator`` with the configured training/validation
        generators, epoch count and callbacks.

        For ``self.framework == 'torch'`` it runs a manual loop. Each epoch
        does the following:

        * Train on every batch of ``self.train_datagen`` (zero_grad ->
          forward -> ``self.loss`` -> backward -> optimizer step).
        * Evaluate on every batch of ``self.val_datagen`` under
          ``torch.no_grad()``. The validation loss is the unweighted mean of
          the per-batch losses.
        * Call ``self._run_torch_callbacks(train_loss, val_loss)`` with the
          loss of the last training batch and the mean validation loss, both
          as numpy values. Training stops early if it returns a falsy value.

        After the loop, ``self.save_model()`` is called (torch only).

        Batches are read as dicts with an ``'image'`` and a ``'mask'`` entry.
        If ``config['data_specs']['additional_inputs']`` is set, the model is
        fed a list ``[batch['image'], batch[k] for k in additional_inputs]``
        instead of the bare image tensor. Tensors are moved to CUDA when
        ``torch.cuda.is_available()``.

        Raises:
            ValueError: if, in torch mode, the training or validation
                generator yields no batches during an epoch.
        """
        if not self.is_initialized:
            self.initialize_model()

        if self.framework == 'keras':
            self.model.fit_generator(self.train_datagen,
                                     validation_data=self.val_datagen,
                                     epochs=self.epochs,
                                     callbacks=self.callbacks)

        elif self.framework == 'torch':
            # Loop-invariant: decide once where tensors live and which extra
            # model inputs (if any) are stacked alongside the image.
            use_cuda = torch.cuda.is_available()
            extra_inputs = self.config['data_specs'].get('additional_inputs',
                                                         None)

            def to_device(tensor):
                # Move to GPU only when one is available.
                return tensor.cuda() if use_cuda else tensor

            def prepare_batch(batch):
                # Build (model input, float target) for one batch dict.
                if extra_inputs is not None:
                    # list() so a tuple in the config works as well as a list.
                    data = [to_device(torch.Tensor(batch[key]))
                            for key in ['image'] + list(extra_inputs)]
                else:
                    data = to_device(batch['image'])
                target = to_device(batch['mask']).float()
                return data, target

            for epoch in range(self.epochs):
                if self.verbose:
                    print('Beginning training epoch {}'.format(epoch))
                # TRAINING
                self.model.train()
                # Reset each epoch so a stale loss from a previous epoch is
                # never reported if this epoch yields no batches.
                loss = None
                for batch_idx, batch in enumerate(self.train_datagen):
                    data, target = prepare_batch(batch)
                    self.optimizer.zero_grad()
                    output = self.model(data)
                    loss = self.loss(output, target)
                    loss.backward()
                    self.optimizer.step()

                    if self.verbose and batch_idx % 10 == 0:
                        print('    loss at batch {}: {}'.format(
                            batch_idx, loss), flush=True)
                if loss is None:
                    raise ValueError(
                        'train_datagen yielded no batches in epoch {}'.format(
                            epoch))

                # VALIDATION
                with torch.no_grad():
                    self.model.eval()
                    torch.cuda.empty_cache()
                    val_losses = []
                    for batch in self.val_datagen:
                        data, target = prepare_batch(batch)
                        val_output = self.model(data)
                        val_losses.append(self.loss(val_output, target))
                    if not val_losses:
                        # torch.stack([]) would raise an opaque RuntimeError.
                        raise ValueError(
                            'val_datagen yielded no batches in epoch {}'.format(
                                epoch))
                    # Unweighted mean over batches (not over samples).
                    val_loss = torch.mean(torch.stack(val_losses))
                if self.verbose:
                    print()
                    print('    Validation loss at epoch {}: {}'.format(
                        epoch, val_loss))
                    print()
                # Callbacks get the LAST training batch loss, not an epoch mean.
                check_continue = self._run_torch_callbacks(
                    loss.detach().cpu().numpy(),
                    val_loss.detach().cpu().numpy())
                if not check_continue:
                    break

            self.save_model()