def train()

in notebook/source/monai_dicom.py [0:0]


def _build_dicom_file_lists(data_dir):
    """Parse ``<data_dir>/meta-data.json`` into parallel DICOM-path / label lists.

    The manifest is a JSON list of records. Class names are the keys of the
    ``'disease'`` mapping in the FIRST record's annotation content; every
    record's label is mapped to its index in that list.

    Args:
        data_dir: Directory containing ``meta-data.json`` and the ``.dcm`` files.

    Returns:
        ``(image_file_list, image_label_list, class_names)``, where
        ``image_file_list[k]`` is the DICOM path whose class index is
        ``image_label_list[k]``.

    Raises:
        ValueError: if the manifest has no records, or a record's label is not
            one of ``class_names``.
        IndexError: if an ``imageurl`` has no ``'instances/'`` segment.
    """
    with open(os.path.join(data_dir, 'meta-data.json'), encoding='utf-8') as f:
        manifest = json.load(f)
    if not manifest:
        raise ValueError("No records found in {}".format(os.path.join(data_dir, 'meta-data.json')))

    # Class order is fixed by the first record; all label indices refer to it.
    first_content = json.loads(manifest[0]['annotations'][0]['annotationData']['content'])
    class_names = list(first_content['disease'].keys())

    image_file_list = []
    image_label_list = []
    for record in manifest:
        content = json.loads(record['annotations'][0]['annotationData']['content'])
        # 'labels' is itself a JSON-encoded string nested inside the content JSON.
        label_dict = json.loads(content['labels'])
        # The instance id is the URL text between 'instances/' and '/file';
        # the corresponding image is stored as <data_dir>/<id>.dcm.
        instance_id = label_dict['imageurl'].split("/file")[0].split("instances/")[1]
        image_file_list.append(os.path.join(data_dir, instance_id + '.dcm'))
        image_label_list.append(class_names.index(label_dict['label'][0]))
    return image_file_list, image_label_list, class_names


def train(args):
    """Train a 2-D, single-channel MONAI DenseNet-121 classifier and save it.

    Args:
        args: Parsed hyperparameters. Attributes read here: ``hosts``,
            ``backend``, ``current_host``, ``num_gpus``, ``seed``,
            ``data_dir``, ``batch_size``, ``epochs``, ``model_dir``.

    Raises:
        ValueError: if ``meta-data.json`` has no records or contains an
            unknown label.
        RuntimeError: if the training data loader yields no batches.
    """
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = args.num_gpus > 0
    logger.debug("Number of gpus available - {}".format(args.num_gpus))
    kwargs = {'num_workers': 10, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")

    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
        logger.debug('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
            args.backend, dist.get_world_size()) + 'Current host rank is {}. Number of gpus: {}'.format(
            dist.get_rank(), args.num_gpus))
        # NOTE(review): the process group is initialised, but the model below is
        # not wrapped in DistributedDataParallel and the loader is built with a
        # literal False (presumably its is_distributed flag -- confirm against
        # _get_train_data_loader), so each host appears to train independently.

    # Seed the RNGs for reproducible initialisation and shuffling.
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    image_file_list, image_label_list, class_names = _build_dicom_file_lists(args.data_dir)
    print("Training count =", len(image_file_list))

    train_loader = _get_train_data_loader(args.batch_size, image_file_list, image_label_list, False, **kwargs)

    model = densenet121(
        spatial_dims=2,
        in_channels=1,
        out_channels=len(class_names),
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = args.epochs
    # len(DataLoader) is the loader's own batch count (it accounts for a trailing
    # partial batch), so the progress counter's total matches what is yielded.
    num_batches = len(train_loader)

    for epoch in range(epoch_num):
        logger.info('-' * 10)
        logger.info(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0  # stays 0 if the loader is empty; checked after the loop
        for step, batch_data in enumerate(train_loader, start=1):
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() copies the scalar to the host; do it once per step.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            logger.info(f"{step}/{num_batches}, train_loss: {batch_loss:.4f}")
        if step == 0:
            raise RuntimeError("Training data loader produced no batches; cannot compute the epoch loss.")
        epoch_loss /= step
        logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    save_model(model, args.model_dir)