def main()

in train.py [0:0]


def main(args):
	# create the model checkpoint directory if it does not exist
	os.makedirs(args.model_path, exist_ok=True)

	# image preprocessing
	transform = transforms.Compose([
		transforms.Resize(args.crop_size),
		transforms.ToTensor(),
		transforms.Normalize((0.485, 0.456, 0.406),
			(0.229, 0.224, 0.225))])
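	# the mean/std above are the standard ImageNet normalization statistics
	# expected by torchvision's pretrained CNN backbones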

	# load vocab wrapper
	with open(args.vocab_path, 'rb') as f:
		vocab = pickle.load(f)
	print ("cluster sizes: ", vocab.get_shapes())

	with open(args.annotation_path, 'rb') as f:
		annotation = pickle.load(f)
	print ("annotations size:", len(annotation))
	
	# build data loader
	data_loader = get_loader(annotation, args.image_dir, args.h_dir, args.openpose_dir, vocab, transform, 
		args.batch_size, shuffle=True, num_workers=args.num_workers, seq_length=args.seq_length)
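	# each batch yields (images, poses, homography, poses2, lengths),
	# matching the unpacking in the training loop below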

	upp_size, low_size = vocab.get_shapes()
	encoder = EncoderCNN(args.embed_size).to(device)

	if args.upp:
		decoder = DecoderRNN(args.embed_size, 
							 args.hidden_size, 
							 upp_size+1,
							 args.num_layers).to(device)
	elif args.low:
		decoder = DecoderRNN(args.embed_size, 
							 args.hidden_size, 
							 low_size+1,
							 args.num_layers).to(device)
	else:
		print('Please specify --upp or --low to select the upper- or lower-body model to train')
		exit(1)
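	# the +1 on the output dimension presumably reserves one extra class beyond
	# the pose clusters (e.g. an end-of-sequence token); this is an assumption,
	# not confirmed by this snippet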

	# loss and optimizer
	criterion = nn.CrossEntropyLoss()
	params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
	optimizer = torch.optim.Adam(params, lr=args.learning_rate)
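	# only the decoder and the encoder's projection head (linear + batch norm)
	# receive gradient updates, suggesting a frozen pretrained backbone
	# (assumption; see the EncoderCNN sketch after this function)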

	# train the models
	total_step = len(data_loader)
	print ("total iter", total_step)
	for epoch in range(args.num_epochs):
		for i, (images, poses, homography, poses2, lengths) in enumerate(data_loader):
			images = images.to(device)
			poses = poses.to(device)
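			# flatten the padded pose targets into one tensor of valid timesteps:
			# [0] takes the PackedSequence's .data, so padding never reaches the loss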
			targets = pack_padded_sequence(poses, lengths, batch_first=True)[0]

			# forward, backward, optimize
			features = encoder(images)
			outputs = decoder(features, homography, poses2, lengths)
			loss = criterion(outputs, targets)
			decoder.zero_grad()
			encoder.zero_grad()
			loss.backward()
			optimizer.step()

			if i % args.log_step == 0:
				print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
					.format(epoch+1, args.num_epochs, i+1, total_step, loss.item(), np.exp(loss.item())))

			if ((i+1) % args.save_step == 0) or (i == total_step-1):
				torch.save(decoder.state_dict(), os.path.join(args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
				torch.save(encoder.state_dict(), os.path.join(args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
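
For context, the optimizer above updates only the decoder plus encoder.linear and encoder.bn, the usual pattern of a frozen pretrained backbone with a trainable projection head. Below is a minimal sketch of an EncoderCNN consistent with that usage; the ResNet-152 backbone and layer details are assumptions, not confirmed by this file.

import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
	def __init__(self, embed_size):
		super(EncoderCNN, self).__init__()
		# assumption: a pretrained ResNet backbone with its classifier removed
		resnet = models.resnet152(pretrained=True)
		modules = list(resnet.children())[:-1]
		self.resnet = nn.Sequential(*modules)
		# trainable projection head: the encoder.linear / encoder.bn that
		# the optimizer in main() actually updates
		self.linear = nn.Linear(resnet.fc.in_features, embed_size)
		self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

	def forward(self, images):
		with torch.no_grad():  # keep the backbone frozen
			features = self.resnet(images)
		features = features.reshape(features.size(0), -1)
		return self.bn(self.linear(features))

And a hedged sketch of the command-line entry point that would drive main(); the flag names are inferred from the args attributes accessed above, and every default value is illustrative only.

import argparse

if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	# paths (names inferred from main(); defaults are placeholders)
	parser.add_argument('--model_path', type=str, default='models/')
	parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl')
	parser.add_argument('--annotation_path', type=str, default='data/annotations.pkl')
	parser.add_argument('--image_dir', type=str, default='data/images/')
	parser.add_argument('--h_dir', type=str, default='data/homography/')
	parser.add_argument('--openpose_dir', type=str, default='data/openpose/')
	# model / training hyperparameters (illustrative defaults)
	parser.add_argument('--crop_size', type=int, default=224)
	parser.add_argument('--embed_size', type=int, default=256)
	parser.add_argument('--hidden_size', type=int, default=512)
	parser.add_argument('--num_layers', type=int, default=1)
	parser.add_argument('--seq_length', type=int, default=10)
	parser.add_argument('--batch_size', type=int, default=32)
	parser.add_argument('--num_workers', type=int, default=2)
	parser.add_argument('--num_epochs', type=int, default=5)
	parser.add_argument('--learning_rate', type=float, default=1e-3)
	parser.add_argument('--log_step', type=int, default=10)
	parser.add_argument('--save_step', type=int, default=1000)
	# exactly one of these must be set (see the check in main)
	parser.add_argument('--upp', action='store_true', help='train the upper-body model')
	parser.add_argument('--low', action='store_true', help='train the lower-body model')
	main(parser.parse_args())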