in source/monai_skin_cancer.py [0:0]
def train(args):
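    """Train a MONAI DenseNet-121 classifier on the HAM10000 skin-cancer dataset.

    `args` is assumed to follow the SageMaker training-script convention used in this
    repository: hosts/current_host/backend for optional distributed setup, plus
    num_gpus, seed, data_dir, batch_size, test_batch_size, epochs, and model_dir.
    """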
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = args.num_gpus > 0
    logger.debug("Number of gpus available - {}".format(args.num_gpus))
    kwargs = {'num_workers': 10, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")
    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
        logger.debug("Initialized the distributed environment: '{}' backend on {} nodes. "
                     "Current host rank is {}. Number of gpus: {}".format(
                         args.backend, dist.get_world_size(), dist.get_rank(), args.num_gpus))
    # set the seed for generating random numbers
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)
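
    # The dataset is expected as <data_dir>/HAM10000.tar.gz; extracting it yields
    # HAM10000/train_dir with one subdirectory per class (the folder name is the class label).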
    torchtext.utils.extract_archive(args.data_dir + '/HAM10000.tar.gz', args.data_dir)

    # build file lists
    data_dir = args.data_dir + '/HAM10000/train_dir'
    class_names = sorted([x for x in os.listdir(data_dir)
                          if os.path.isdir(os.path.join(data_dir, x))])
    num_class = len(class_names)
    image_files = [[os.path.join(data_dir, class_name, x)
                    for x in os.listdir(os.path.join(data_dir, class_name))]
                   for class_name in class_names]
    image_file_list = []
    image_label_list = []
    for i, class_name in enumerate(class_names):
        image_file_list.extend(image_files[i])
        image_label_list.extend([i] * len(image_files[i]))
    num_total = len(image_label_list)
    image_width, image_height = Image.open(image_file_list[0]).size
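
    # Randomly assign each image to the train/validation/test split (~80/10/10 in expectation).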
    valid_frac, test_frac = 0.1, 0.1
    trainX, trainY = [], []
    valX, valY = [], []
    testX, testY = [], []

    for i in range(num_total):
        rann = np.random.random()
        if rann < valid_frac:
            valX.append(image_file_list[i])
            valY.append(image_label_list[i])
        elif rann < test_frac + valid_frac:
            testX.append(image_file_list[i])
            testY.append(image_label_list[i])
        else:
            trainX.append(image_file_list[i])
            trainY.append(image_label_list[i])

    logger.info("Training count = {}, Validation count = {}, Test count = {}".format(
        len(trainX), len(valX), len(testX)))
    train_loader = _get_train_data_loader(args.batch_size, trainX, trainY, False, **kwargs)
    val_loader = _get_test_data_loader(args.test_batch_size, valX, valY, **kwargs)
    test_loader = _get_test_data_loader(args.test_batch_size, testX, testY, **kwargs)

    # create the model
    model = densenet121(
        spatial_dims=2,
        in_channels=3,
        out_channels=num_class
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    epoch_num = args.epochs
    val_interval = 1
    # train the model; validate every val_interval epochs and keep the best checkpoint
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    epoch_len = len(train_loader.dataset) // train_loader.batch_size

    for epoch in range(epoch_num):
        logger.info('-' * 10)
        logger.info(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            logger.info(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)

                auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True)
                metric_values.append(auc_metric)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                # checkpoint whenever the validation AUC improves
                if auc_metric > best_metric:
                    best_metric = auc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), 'best_metric_model.pth')
                    logger.info('saved new best metric model')
                logger.info(f"current epoch: {epoch + 1} current AUC: {auc_metric:.4f}"
                            f" current accuracy: {acc_metric:.4f} best AUC: {best_metric:.4f}"
                            f" at epoch: {best_metric_epoch}")

    logger.info(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    save_model(model, args.model_dir)
    # test data: classification report on the held-out test split using the best checkpoint
    model.load_state_dict(torch.load('best_metric_model.pth'))
    model.to(device)
    model.eval()

    y_true = list()
    y_pred = list()
    with torch.no_grad():
        for test_data in test_loader:
            test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
            pred = model(test_images).argmax(dim=1)
            for i in range(len(pred)):
                y_true.append(test_labels[i].item())
                y_pred.append(pred[i].item())

    from sklearn.metrics import classification_report
    logger.info(classification_report(y_true, y_pred, target_names=class_names, digits=4))