in libs/solaris/nets/train.py [0:0]
def train(self):
    """Run training on the model.

    Initializes the model first if ``self.is_initialized`` is falsy, then
    trains according to ``self.framework``:

    * ``'keras'``: delegates to ``self.model.fit_generator`` with
      ``self.train_datagen``, ``self.val_datagen``, ``self.epochs`` and
      ``self.callbacks``.
    * ``'torch'``: runs a manual loop for up to ``self.epochs`` epochs. Each
      epoch takes one optimizer step per training batch, then computes the
      mean validation loss under ``torch.no_grad()``. It then calls
      ``self._run_torch_callbacks(train_loss, val_loss)`` with the epoch-mean
      training loss and the mean validation loss, both converted to numpy.
      Training stops early if that call returns a falsy value.

    ``self.save_model()`` is called once the framework-specific training
    has finished.

    Raises
    ------
    ValueError
        (torch backend) If ``self.train_datagen`` or ``self.val_datagen``
        yields no batches in an epoch, since no loss can be computed.
    """
    if not self.is_initialized:
        self.initialize_model()

    if self.framework == 'keras':
        self.model.fit_generator(self.train_datagen,
                                 validation_data=self.val_datagen,
                                 epochs=self.epochs,
                                 callbacks=self.callbacks)
    elif self.framework == 'torch':
        # Loop invariants: evaluate once rather than once per batch.
        use_cuda = torch.cuda.is_available()
        # When set, the model receives a list of tensors
        # (['image'] + additional_inputs) instead of a single image tensor.
        additional_inputs = self.config['data_specs'].get(
            'additional_inputs', None)

        def _prepare_batch(batch):
            """Return ``(data, target)`` for one batch, on GPU if available."""
            if additional_inputs is not None:
                data = []
                for key in ['image'] + additional_inputs:
                    tensor = torch.Tensor(batch[key])
                    data.append(tensor.cuda() if use_cuda else tensor)
            else:
                data = batch['image'].cuda() if use_cuda else batch['image']
            # Move to the device before casting, as the original code did.
            target = batch['mask'].cuda() if use_cuda else batch['mask']
            return data, target.float()

        for epoch in range(self.epochs):
            if self.verbose:
                print('Beginning training epoch {}'.format(epoch))

            # TRAINING
            self.model.train()
            # Detached per-batch losses, averaged at epoch end so the
            # callbacks see the whole epoch rather than only the last batch.
            train_losses = []
            for batch_idx, batch in enumerate(self.train_datagen):
                data, target = _prepare_batch(batch)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.loss(output, target)
                loss.backward()
                self.optimizer.step()
                train_losses.append(loss.detach())
                if self.verbose and batch_idx % 10 == 0:
                    print(' loss at batch {}: {}'.format(
                        batch_idx, loss), flush=True)
            if not train_losses:
                raise ValueError(
                    'train_datagen yielded no batches in epoch {}'.format(
                        epoch))
            train_loss = torch.mean(torch.stack(train_losses))
            # TODO: metrics in self.metrics are not computed for the torch
            # backend.

            # VALIDATION
            with torch.no_grad():
                self.model.eval()
                if use_cuda:
                    torch.cuda.empty_cache()
                val_losses = []
                for batch in self.val_datagen:
                    data, target = _prepare_batch(batch)
                    val_losses.append(self.loss(self.model(data), target))
                if not val_losses:
                    raise ValueError(
                        'val_datagen yielded no batches in epoch {}'.format(
                            epoch))
                val_loss = torch.mean(torch.stack(val_losses))
            if self.verbose:
                print()
                print(' Validation loss at epoch {}: {}'.format(
                    epoch, val_loss))
                print()

            check_continue = self._run_torch_callbacks(
                train_loss.detach().cpu().numpy(),
                val_loss.detach().cpu().numpy())
            if not check_continue:
                break

    self.save_model()