def main()

in style-transfer/Training/StyleTransferTraining/src/train.py [0:0]
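
main() relies on module-level imports and globals defined earlier in train.py. The sketch below is an assumption about that context, not the file's actual header: the real script populates options via its argument parser, and device is a TensorFlow device string.

import os
import time

import numpy as np
import tensorflow as tf

import stylenet  # image transformation network shipped with the repo
import vgg       # VGG-19 feature helpers shipped with the repo

# Sketch only: train.py sets these via argparse before calling main().
options = None     # expected attributes: input_dir, output_dir, log_dir,
                   # style_image, batch_size, epoch, lr, lambda_tv,
                   # lambda_feat, lambda_style
device = '/gpu:0'  # device string used when building the graph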


def main():
    global options, device

    # Get the ENV context
    script_dir = os.path.dirname(__file__)
    env = os.environ.copy()
    
    # Set the input folder
    input_dir = os.path.expanduser(options.input_dir) if options.input_dir \
        else os.path.join(script_dir, '..', 'data')
    vgg_path = os.path.join(input_dir, 'vgg', 'imagenet-vgg-verydeep-19.mat')
    coco_dir = os.path.join(input_dir, 'train')
    if not os.path.isdir(input_dir):
        fail('Failed to find the input folder at ' + input_dir)
    if not os.path.isfile(vgg_path):
        error('Failed to find the VGG model file at ' + vgg_path)
        fail('Please download it from http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat')
    if not os.path.isdir(coco_dir):
        error('Failed to find the COCO 2014 training images in ' + coco_dir)
        fail('Please download it from http://images.cocodataset.org/zips/train2014.zip')

    # Set the output folder
    output_dir = os.path.expanduser(options.output_dir) if options.output_dir \
        else env.get('OUTPUT_DIR', os.path.join(script_dir, '..', 'output'))
    model_dir = os.path.join(output_dir, 'checkpoint')
    if not os.path.isdir(model_dir):
        info('Creating a folder to store checkpoints at ' + model_dir)
        os.makedirs(model_dir)
    
    # Set the TensorBoard folder
    log_dir = os.path.expanduser(options.log_dir) if options.log_dir \
        else env.get('LOG_DIR', os.path.join(script_dir, '..', 'log'))
    if not os.path.isdir(log_dir):
        info('Creating a folder to store TensorBoard events at ' + log_dir)
        os.makedirs(log_dir)
    
    # Set the style image path
    style_path = os.path.expanduser(options.style_image)
    if not os.path.isfile(style_path):
        style_path = os.path.join(input_dir, 'style_images', options.style_image)
    style_name = os.path.basename(os.path.splitext(style_path)[0])
    ckpt_path = os.path.join(model_dir, style_name + '.ckpt')
    if not os.path.isfile(style_path):
        fail('Failed to find the style image at ' + style_path)

    # Set hyperparameters
    batch_size = options.batch_size
    epochs = options.epoch
    lr = options.lr
    lambda_tv = options.lambda_tv
    lambda_feat = options.lambda_feat
    lambda_style = options.lambda_style
    
    # Print parsed arguments
    info('--------- Training parameters -------->')
    info('Style image path: ' + style_path)
    info('VGG model path: ' + vgg_path)
    info('Training image dir: ' + coco_dir)
    info('Checkpoint path: ' + ckpt_path)
    info('TensorBoard log dir: ' + log_dir)
    info('Training device: ' + device)
    info('Batch size: %d' % batch_size)
    info('Epoch count: %d' % epochs)
    info('Learning rate: ' + str(lr))
    info('Lambda tv: ' + str(lambda_tv))
    info('Lambda feat: ' + str(lambda_feat))
    info('Lambda style: ' + str(lambda_style))
    info('<-------- Training parameters ---------')

    # COCO images to train
    content_targets = list_jpgs(coco_dir)
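    # Drop the tail so the dataset size is an exact multiple of the batch size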
    if len(content_targets) % batch_size != 0:
        content_targets = content_targets[:-(len(content_targets) % batch_size)]
    info('Total training data size: %d' % len(content_targets))
    
    # Image shape
    image_shape = (224, 224, 3)
    batch_shape = (batch_size,) + image_shape
    
    # Style target
    style_target = read_img(style_path)
    style_shape = (1,) + style_target.shape
    
    with tf.device(device), tf.Session() as sess:
        # Compute the Gram matrix of the style target
        style_image = tf.placeholder(tf.float32, shape=style_shape, name='style_image')
        vggstyletarget = vgg.net(vgg_path, vgg.preprocess(style_image))
        style_vgg = vgg.get_style_vgg(vggstyletarget, style_image, np.array([style_target]))
        
        # Content target feature 
        content_vgg = {}
        inputs = tf.placeholder(tf.float32, shape=batch_shape, name='inputs')
        content_net = vgg.net(vgg_path, vgg.preprocess(inputs))
        content_vgg['relu4_2'] = content_net['relu4_2']
        
        # Feature after transformation
        outputs = stylenet.net(inputs / 255.0)
        vggoutputs = vgg.net(vgg_path, vgg.preprocess(outputs))
        
        # Compute feature (content) loss
        loss_f = lambda_feat * vgg.total_content_loss(vggoutputs, content_vgg, batch_size)

        # Compute style loss
        loss_s = lambda_style * vgg.total_style_loss(vggoutputs, style_vgg, batch_size)

        # Total variation denoising
        loss_tv = lambda_tv * vgg.total_variation_regularization(outputs, batch_size, batch_shape)

        # Total loss
        total_loss = loss_f + loss_s + loss_tv
        train_step = tf.train.AdamOptimizer(lr).minimize(total_loss)
        
        # Create summary
        tf.summary.scalar('loss', total_loss)
        merged = tf.summary.merge_all()
        
        # Saver for writing checkpoints; created after all model variables exist
        saver = tf.train.Saver()
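
    # Training runs in a fresh session; allow_soft_placement lets ops without
    # a kernel on the chosen device fall back to the CPU.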
    
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Initialize variables, then restore from a checkpoint if one exists
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            info('Restoring from ' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        
        # Write the graph
        writer = tf.summary.FileWriter(log_dir, sess.graph)
        
        # Start to train
        total_step = 0
        for epoch in range(epochs):
            info('epoch: %d' % epoch)
            step = 0
            while step * batch_size < len(content_targets):
                time_start = time.time()
                
                # Load one batch
                batch = np.zeros(batch_shape, dtype=np.float32)
                for i, img in enumerate(content_targets[step * batch_size : (step + 1) * batch_size]):
                    batch[i] = read_img(img, image_shape).astype(np.float32) # (224,224,3)
                
                # Advance one step
                step += 1
                total_step += 1
                _, loss, summary = sess.run([train_step, total_loss, merged], feed_dict={inputs: batch})
             
                time_elapse = time.time() - time_start
                if total_step % 5 == 0:
                    info('[step {}] elapsed time: {} loss: {}'.format(total_step, time_elapse, loss))
                    writer.add_summary(summary, total_step)
                
                # Write checkpoint
                if total_step % 2000 == 0:                                        
                    info('Saving checkpoint to ' + ckpt_path)
                    saver.save(sess, ckpt_path, global_step=total_step)
        
        info('Saving final checkpoint to ' + ckpt_path)
        saver.save(sess, ckpt_path, global_step=total_step)
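
main() also calls a handful of helpers (info, error, fail, list_jpgs, read_img) defined elsewhere in train.py. Below is a rough sketch of plausible implementations, assuming scipy-based image I/O; the actual versions may differ.

import os
import sys

import scipy.misc

def info(msg):
    print('INFO: ' + msg)

def error(msg):
    sys.stderr.write('ERROR: ' + msg + '\n')

def fail(msg):
    # Report the error and abort the run
    error(msg)
    sys.exit(1)

def list_jpgs(folder):
    # Full paths of the JPEG files directly inside a folder
    return [os.path.join(folder, name) for name in sorted(os.listdir(folder))
            if name.lower().endswith(('.jpg', '.jpeg'))]

def read_img(path, shape=None):
    # Load an image as an RGB array, optionally resized to (height, width)
    img = scipy.misc.imread(path, mode='RGB')
    if shape is not None:
        img = scipy.misc.imresize(img, shape[:2])
    return img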