def _train()

in datasets/bbbc-021/scripts/bbbc021-1-train-script.py


def _train(args):
    
    epochs = args.epochs
    startingEpoch = 0
    
    torch.manual_seed(args.seed)

    # NOTE: For Horovod, use: https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/pytorch_horovod_mnist/code/mnist.py
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))

    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        #os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
        logger.info(
            "Initialized the distributed environment: '{}' backend on {} nodes. "
            "Current host rank is {}. Using cuda: {}. Number of gpus: {}".format(
                args.backend, dist.get_world_size(), dist.get_rank(),
                torch.cuda.is_available(), args.num_gpus))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.info("Device Type: {}".format(device))
    
    # x: image tensors, y: class labels, z: subclass labels
    x_train, y_train, z_train = load_training_data(args.train_list_file)

    x_train = x_train.reshape(-1, channels, height_width, height_width)
    print("x_train reshape={}".format(x_train.shape))
    
    if y_train.shape[0] != z_train.shape[0]:
        print("Error - y_train and z_train must have the same length")
        return
    
    # Group sample indices by class, then by subclass within each class.
    classDict = defaultdict(lambda: defaultdict(list))
    for i in range(y_train.shape[0]):
        y = y_train[i]
        z = z_train[i]
        classDict[y][z].append(i)
        
    for ck, cv in classDict.items():
        for sk, sv in cv.items():
            sl = len(sv)
            print("Class {} Subclass {} has {} members".format(ck, sk, sl))
        
    class AnchorPositivePairs:
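        """Generates batches of (anchor, positive) image pairs, one pair per class.

        For each class index, an anchor and a positive sample are drawn from
        randomly chosen subclasses of that class, re-drawing the positive until
        its index differs from the anchor's (this assumes every class contains
        at least two samples). Pixel values are scaled from 16-bit to [0, 1].
        Output shape: (2, num_classes, channels, height_width, height_width).
        """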
        def __init__(self):
            self.num_batches = 1

        def __len__(self):
            return self.num_batches

        def getitem(self):
            x = np.empty((2, num_classes, channels, height_width, height_width), dtype=np.float32)
            for class_idx in range(num_classes):
                subclasses_for_class = classDict[class_idx]
                slist = list(subclasses_for_class.values())
                anchor_subclass_list = random.choice(slist)
                positive_subclass_list = random.choice(slist)
                anchor_idx = random.choice(anchor_subclass_list)
                positive_idx = random.choice(positive_subclass_list)
                while positive_idx == anchor_idx:
                    positive_idx = random.choice(positive_subclass_list)
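                # Scale 16-bit pixel values from [0, 65535] to [0.0, 1.0].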
                x[0, class_idx] = x_train[anchor_idx].astype(np.float32) / 65535.0
                x[1, class_idx] = x_train[positive_idx].astype(np.float32) / 65535.0
            
            return torch.tensor(x)

    pairGenerator = AnchorPositivePairs()

    # Resume from a SageMaker checkpoint if one exists, restoring both the
    # model weights and the epoch counter.
    checkpointModelPath = os.path.join("/opt/ml/checkpoints", 'model.pth')
    if os.path.exists(checkpointModelPath):
        print("Reading checkpoint model")
        model = model_fn("/opt/ml/checkpoints")
        checkpointEpochPath = os.path.join("/opt/ml/checkpoints", "epoch.txt")
        if os.path.exists(checkpointEpochPath):
            with open(checkpointEpochPath, "r") as text_file:
                epochStr = text_file.read()
                startingEpoch = int(epochStr)
                print("Resuming from epoch {}".format(startingEpoch))
    else:
        print("No checkpoint model found")
        model = Net()

    if torch.cuda.device_count() > 1:
        logger.info("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Labels are simply [0, 1, ..., num_classes - 1]: the target for anchor i
    # is positive i, i.e. the main diagonal of the similarity matrix.
    sparse_labels = torch.arange(num_classes, dtype=torch.long).to(device)

    best_loss = -1.0  # negative sentinel: no epoch completed yet
    best_epoch = 0
    for epoch in range(startingEpoch, epochs):
    
        i = 0
        running_loss = 0.0
        epoch_loss = 0.0
        for minibatch in range(minibatches):

            data = pairGenerator.getitem()

            anchors, positives = data[0].to(device), data[1].to(device)
            #anchors, positives = data[0], data[1]
            
            optimizer.zero_grad()
                        
            anchor_embeddings = model(anchors)
            positive_embeddings = model(positives)

            similarities = torch.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )
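            # Shape note: assuming the model maps each image to a single
            # embedding vector, anchor_embeddings and positive_embeddings are
            # (num_classes, embedding_dim), so similarities is
            # (num_classes, num_classes) with entry [a, p] = dot(anchor_a, positive_p).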
            
            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature

            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence [0, 1, 2, ..., num_classes - 1],
            # since we want the main diagonal values, which correspond to the
            # anchor/positive pairs, to be high. This loss pulls embeddings for
            # anchor/positive pairs together and pushes all other pairs apart.

            # For CrossEntropyLoss
            loss = criterion(similarities, sparse_labels)
            
            # For CosineEmbeddingLoss
            #loss = criterion(anchor_embeddings, positive_embeddings, y)
            
            loss.backward()
            optimizer.step()

            # print statistics
            item_loss = loss.item()
            running_loss += item_loss
            epoch_loss += item_loss
            # At the start of each epoch, report the best epoch seen so far (1-based).
            if i == 0:
                print("Best epoch={}".format(best_epoch + 1))
            if i % 200 == 199:
                print('v2 [%d, %5d] loss: %.6f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
            i += 1
        
        # Checkpoint whenever the epoch loss improves; otherwise stop early.
        if best_loss < 0.0 or epoch_loss < best_loss:
            print("Best loss={} Epoch loss={}".format(best_loss, epoch_loss))
            print("Saving checkpoint to {}".format(args.model_dir))
            nextEpoch = epoch + 1
            _save_model(model, args.model_dir, nextEpoch)
            best_loss = epoch_loss
            best_epoch = epoch
        else:
            print("Stopping due to lack of improvement in the prior epoch")
            break

        # Ensure the model is back on the training device in case
        # checkpointing moved it off the GPU.
        model = model.to(device)

    print('Finished Training')
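
Usage note: the core of the loop above is an N-pairs-style contrastive loss, a
softmax over anchor/positive similarities with diagonal targets. Below is a
minimal, self-contained sketch of just that step on random embeddings; the
class count and embedding dimension are illustrative assumptions, not values
from the script.

import torch
import torch.nn as nn

num_classes, embedding_dim = 8, 16  # illustrative values only
anchor_embeddings = torch.randn(num_classes, embedding_dim)
positive_embeddings = torch.randn(num_classes, embedding_dim)

# Pairwise dot products: entry [a, p] compares anchor a with positive p.
similarities = torch.einsum("ae,pe->ap", anchor_embeddings, positive_embeddings)
similarities /= 0.2  # temperature scaling, as in _train

# The target for anchor i is positive i, i.e. the diagonal of `similarities`.
labels = torch.arange(num_classes)
loss = nn.CrossEntropyLoss()(similarities, labels)
print(loss.item())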