in notebook/source/monai_dicom.py [0:0]
def train(args):
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = args.num_gpus > 0
    logger.debug("Number of gpus available - {}".format(args.num_gpus))
    kwargs = {'num_workers': 10, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")
    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
        logger.debug('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
            args.backend, dist.get_world_size()) + 'Current host rank is {}. Number of gpus: {}'.format(
            dist.get_rank(), args.num_gpus))
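        # Note: in a SageMaker training job, args.hosts and args.current_host are
        # typically parsed from the SM_HOSTS and SM_CURRENT_HOST environment
        # variables (e.g. hosts=["algo-1", "algo-2"], current_host="algo-1");
        # those example values are assumptions, not taken from this script.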
    # set the seed for generating random numbers
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)
    # Build the file and label lists from the annotation manifest
    image_label_list = []
    image_file_list = []
    metadata = args.data_dir + '/meta-data.json'
    # Load labels
    with open(metadata) as f:
        manifest = json.load(f)
    class_names = list(json.loads(manifest[0]['annotations'][0]['annotationData']['content'])['disease'].keys())
    num_class = len(class_names)
    for entry in manifest:
        label_dict = json.loads(json.loads(entry['annotations'][0]['annotationData']['content'])['labels'])
        # Derive the local DICOM filename from the imageurl field
        # (the text between "instances/" and "/file")
        instance_uid = label_dict['imageurl'].split("/file")[0].split("instances/")[1]
        filename = args.data_dir + '/' + instance_uid + '.dcm'
        image_file_list.append(filename)
        image_label_list.append(class_names.index(label_dict['label'][0]))
    print("Training count =", len(image_file_list))
    train_loader = _get_train_data_loader(args.batch_size, image_file_list, image_label_list, False, **kwargs)
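    # _get_train_data_loader is defined elsewhere in this file. A minimal sketch
    # of one possible implementation (an assumption, not the original code),
    # using MONAI transforms to load the DICOM files:
    #
    #   from monai.transforms import Compose, LoadImage, EnsureChannelFirst, ScaleIntensity
    #   from torch.utils.data import Dataset, DataLoader
    #
    #   class DicomDataset(Dataset):
    #       def __init__(self, image_files, labels, transforms):
    #           self.image_files = image_files
    #           self.labels = labels
    #           self.transforms = transforms
    #       def __len__(self):
    #           return len(self.image_files)
    #       def __getitem__(self, index):
    #           return self.transforms(self.image_files[index]), self.labels[index]
    #
    #   def _get_train_data_loader(batch_size, image_files, labels, is_distributed, **kwargs):
    #       transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])
    #       dataset = DicomDataset(image_files, labels, transforms)
    #       return DataLoader(dataset, batch_size=batch_size, shuffle=not is_distributed, **kwargs)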
    # Create the model: a 2D DenseNet-121 with one input channel and num_class output classes
    model = densenet121(
        spatial_dims=2,
        in_channels=1,
        out_channels=num_class
    ).to(device)
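    # Not shown in this excerpt: for multi-host training the model would
    # typically be wrapped in DistributedDataParallel after being moved to the
    # device (a sketch under that assumption, not part of the original script):
    #
    #   if is_distributed:
    #       model = torch.nn.parallel.DistributedDataParallel(model)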
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = args.epochs
    val_interval = 1
    # Train the model
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    for epoch in range(epoch_num):
        logger.info('-' * 10)
        logger.info(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        epoch_len = len(train_loader.dataset) // train_loader.batch_size
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            logger.info(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    save_model(model, args.model_dir)
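    # save_model is defined elsewhere in this file. A minimal SageMaker-style
    # sketch of what it might do (an assumption, not the original code):
    #
    #   def save_model(model, model_dir):
    #       path = os.path.join(model_dir, 'model.pth')
    #       torch.save(model.state_dict(), path)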