in train.py [0:0]
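# Imports assumed from the full train.py (this excerpt starts at main()).
# get_loader, EncoderCNN, and DecoderRNN are project-local names; the module
# paths 'data_loader' and 'model' below are layout assumptions, not confirmed.
import os
import pickle
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

from data_loader import get_loader
from model import EncoderCNN, DecoderRNN

# Run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')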
def main(args):
    # Create the model directory if it does not already exist
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    # Image preprocessing: resize, convert to tensor, and normalize with the
    # standard ImageNet mean/std
    transform = transforms.Compose([
        transforms.Resize(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    # Load the vocabulary wrapper (holds the upper/lower-body pose cluster sizes)
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    print("cluster sizes:", vocab.get_shapes())

    # Load the pickled annotations
    with open(args.annotation_path, 'rb') as f:
        annotation = pickle.load(f)
    print("annotations size:", len(annotation))
    # Build the data loader
    data_loader = get_loader(annotation, args.image_dir, args.h_dir, args.openpose_dir,
                             vocab, transform, args.batch_size, shuffle=True,
                             num_workers=args.num_workers, seq_length=args.seq_length)

    upp_size, low_size = vocab.get_shapes()
    # Build the models; the decoder's output size is the pose-cluster count for
    # the selected body half (+1, presumably for an extra end/padding token)
    encoder = EncoderCNN(args.embed_size).to(device)
    if args.upp:
        decoder = DecoderRNN(args.embed_size,
                             args.hidden_size,
                             upp_size + 1,
                             args.num_layers).to(device)
    elif args.low:
        decoder = DecoderRNN(args.embed_size,
                             args.hidden_size,
                             low_size + 1,
                             args.num_layers).to(device)
    else:
        print('Please specify the upper- or lower-body model to train (--upp or --low)')
        sys.exit(1)  # non-zero exit: no model was selected
    # Loss and optimizer: only the decoder plus the encoder's final linear and
    # batch-norm layers are passed to the optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    print("total iterations per epoch:", total_step)
    for epoch in range(args.num_epochs):
        for i, (images, poses, homography, poses2, lengths) in enumerate(data_loader):
            images = images.to(device)
            poses = poses.to(device)
            # Pack the padded pose targets; pack_padded_sequence expects the batch
            # sorted by length in descending order (presumably done in the loader)
            targets = pack_padded_sequence(poses, lengths, batch_first=True)[0]

            # Forward, backward, optimize
            features = encoder(images)
            outputs = decoder(features, homography, poses2, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()
            # Log progress; perplexity is exp(cross-entropy loss)
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                      .format(epoch + 1, args.num_epochs, i, total_step,
                              loss.item(), np.exp(loss.item())))

            # Save checkpoints periodically and at the end of every epoch
            if ((i + 1) % args.save_step == 0) or (i == total_step - 1):
                torch.save(decoder.state_dict(), os.path.join(
                    args.model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
                torch.save(encoder.state_dict(), os.path.join(
                    args.model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
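
# ---------------------------------------------------------------------------
# A minimal sketch of the entry point this main() expects. The flag names are
# inferred from the args.* attributes used above; every default value below is
# an illustrative assumption, not the project's actual configuration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Paths
    parser.add_argument('--model_path', type=str, default='models/', help='directory for saving checkpoints')
    parser.add_argument('--vocab_path', type=str, required=True, help='pickled vocabulary wrapper')
    parser.add_argument('--annotation_path', type=str, required=True, help='pickled annotations')
    parser.add_argument('--image_dir', type=str, required=True, help='directory of input images')
    parser.add_argument('--h_dir', type=str, required=True, help='directory of homography data')
    parser.add_argument('--openpose_dir', type=str, required=True, help='directory of OpenPose outputs')
    # Model selection: exactly one of --upp / --low should be given
    parser.add_argument('--upp', action='store_true', help='train the upper-body model')
    parser.add_argument('--low', action='store_true', help='train the lower-body model')
    # Model/data hyper-parameters (defaults are assumptions)
    parser.add_argument('--crop_size', type=int, default=224)
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--seq_length', type=int, default=10)
    # Training hyper-parameters (defaults are assumptions)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_epochs', type=int, default=5)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--save_step', type=int, default=1000)
    args = parser.parse_args()

    # Example invocation (paths are placeholders):
    #   python train.py --upp --vocab_path data/vocab.pkl --annotation_path data/annotations.pkl \
    #       --image_dir data/images --h_dir data/homography --openpose_dir data/openpose
    main(args)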