in datasets/bbbc-021/scripts/bbbc021-1-train-script.py [0:0]
def _train(args):
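    """Train the embedding network on anchor/positive image pairs.

    Loads the training images and their class/subclass labels, optionally resumes
    from a SageMaker checkpoint, and optimizes the model with a softmax over
    anchor/positive similarities so that embeddings of the same class are pulled
    together.
    """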
    epochs = args.epochs
    startingEpoch = 0
    torch.manual_seed(args.seed)
    # NOTE: For Horovod, use: https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/pytorch_horovod_mnist/code/mnist.py
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
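    # Multi-host training is enabled only when more than one host is configured
    # and a communication backend (e.g. gloo or nccl) was supplied.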
    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        #os.environ['RANK'] = str(host_rank)
        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
                args.backend, dist.get_world_size()) +
            'Current host rank is {}. Using cuda: {}. Number of gpus: {}'.format(
                dist.get_rank(), torch.cuda.is_available(), args.num_gpus))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.info("Device Type: {}".format(device))
    x_train, y_train, z_train = load_training_data(args.train_list_file)
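    # x_train holds the raw image data; y_train and z_train hold the class and
    # subclass labels used below to group samples for anchor/positive pair sampling.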
    x_train = x_train.reshape(-1, channels, height_width, height_width)
    print("x_train reshape={}".format(x_train.shape))
    if y_train.shape[0] != z_train.shape[0]:
        print("Error - y_train and z_train must have the same length")
        return
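    # Index samples by class and then by subclass:
    # classDict[class_label][subclass_label] -> list of row indices into x_train.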
    classDict = defaultdict(lambda: defaultdict(list))
    for i in range(y_train.shape[0]):
        y = y_train[i]
        z = z_train[i]
        classDict[y][z].append(i)
    for ck, cv in classDict.items():
        for sk, sv in cv.items():
            print("Class {} Subclass {} has {} members".format(ck, sk, len(sv)))
    class AnchorPositivePairs:
        def __init__(self):
            self.num_batches = 1

        def __len__(self):
            return self.num_batches

        def getitem(self):
            x = np.empty((2, num_classes, channels, height_width, height_width), dtype=np.float32)
            for class_idx in range(num_classes):
                subclasses_for_class = classDict[class_idx]
                slist = list(subclasses_for_class.values())
                anchor_subclass_list = random.choice(slist)
                positive_subclass_list = random.choice(slist)
                anchor_idx = random.choice(anchor_subclass_list)
                positive_idx = random.choice(positive_subclass_list)
                while positive_idx == anchor_idx:
                    positive_idx = random.choice(positive_subclass_list)
                x[0, class_idx] = x_train[anchor_idx].astype(np.float32) / 65535.0
                x[1, class_idx] = x_train[positive_idx].astype(np.float32) / 65535.0
            return torch.tensor(x)
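    # This class is intentionally minimal; a torch.utils.data.Dataset-style variant
    # would rename getitem to __getitem__(self, idx) and let a DataLoader drive the
    # batching, but here the generator is simply called directly in the training loop.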
    pairGenerator = AnchorPositivePairs()
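    # Resume support: if a checkpoint exists under /opt/ml/checkpoints (the local
    # directory SageMaker uses for checkpoints), reload the model with model_fn() and
    # read epoch.txt to find the epoch to resume from; otherwise start a fresh Net().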
    checkpointModelPath = os.path.join("/opt/ml/checkpoints", 'model.pth')
    if os.path.exists(checkpointModelPath):
        print("Reading checkpoint model")
        model = model_fn("/opt/ml/checkpoints")
        checkpointEpochPath = os.path.join("/opt/ml/checkpoints", "epoch.txt")
        if os.path.exists(checkpointEpochPath):
            with open(checkpointEpochPath, "r") as text_file:
                epochStr = text_file.read()
                startingEpoch = int(epochStr)
            print("Resuming beginning with epoch {}".format(startingEpoch))
    else:
        print("No checkpoint model found")
        model = Net()
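    # Use all GPUs visible on this host via DataParallel before moving the model to the device.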
    if torch.cuda.device_count() > 1:
        logger.info("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
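    # Targets for CrossEntropyLoss: label i for row i, so the i-th anchor is expected
    # to be most similar to the i-th positive (the diagonal of the similarity matrix).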
    # Equivalent to filling sparse_labels[l] = l for each class index.
    sparse_labels = torch.arange(num_classes, dtype=torch.long).to(device)
    best_loss = -1.0
    best_epoch = 0
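    # Each epoch runs `minibatches` anchor/positive batches. The summed loss for the
    # epoch is compared against the best epoch so far; an improving epoch triggers a
    # checkpoint save, while a non-improving epoch stops training early.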
    for epoch in range(startingEpoch, args.epochs):
        i = 0
        running_loss = 0.0
        epoch_loss = 0.0
        for minibatch in range(minibatches):
            data = pairGenerator.getitem()
            anchors, positives = data[0].to(device), data[1].to(device)
            #anchors, positives = data[0], data[1]
            optimizer.zero_grad()
            anchor_embeddings = model(anchors)
            positive_embeddings = model(positives)
            similarities = torch.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )
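            # similarities[a, p] is the dot product of anchor embedding a with positive
            # embedding p, giving a num_classes x num_classes matrix of logits.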
            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature
            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence [0, 1, 2, ..., num_classes - 1], since we
            # want the main-diagonal values, which correspond to the anchor/positive
            # pairs, to be high. This loss moves embeddings for the anchor/positive
            # pairs together and moves all other pairs apart.
            # For CrossEntropyLoss
            loss = criterion(similarities, sparse_labels)
            # For CosineEmbeddingLoss
            #loss = criterion(anchor_embeddings, positive_embeddings, y)
            loss.backward()
            optimizer.step()
            # print statistics
            item_loss = loss.item()
            running_loss += item_loss
            epoch_loss += item_loss
            if i == 0:
                print("Best epoch={}".format(best_epoch + 1))
            if i % 200 == 199:
                print('v2 [%d, %5d] loss: %.6f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
            i += 1
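        # End of epoch: save a checkpoint when the epoch loss improves on the best so
        # far (best_loss < 0.0 covers the first epoch); otherwise stop early.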
        if best_loss < 0.0 or epoch_loss < best_loss:
            print("Best loss={} Epoch loss={}".format(best_loss, epoch_loss))
            print("Saving checkpoint to ", args.model_dir)
            nextEpoch = epoch + 1
            _save_model(model, args.model_dir, nextEpoch)
            best_loss = epoch_loss
            best_epoch = epoch
        else:
            print("Stopping due to lack of improvement in prior epoch")
            break
    model = model.to(device)
    print('Finished Training')