import find_mxnet
import mxnet as mx
import numpy as np
import importlib
import logging
logging.basicConfig(level=logging.DEBUG)
import argparse
from collections import namedtuple
from skimage import io, transform
from skimage.restoration import denoise_tv_chambolle

# Record type whose fields carry the same names as the keys of the progress
# dict that train_nstyle passes to its callback.
CallbackData = namedtuple('CallbackData', ['eps', 'epoch', 'img', 'filename'])

def get_args(arglist=None):
    """Parse the options that control a neural style run.

    Args:
        arglist: optional list of argument strings to parse. When None,
            argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option
        defined below.
    """
    parser = argparse.ArgumentParser(description='neural style')

    # The default must itself be an accepted choice, otherwise spelling out
    # "--model vgg19" is rejected. 'vgg' stays accepted so command lines that
    # parsed before keep parsing.
    parser.add_argument('--model', type=str, default='vgg19',
                        choices=['vgg19', 'vgg'],
                        help='the pretrained model to use')
    parser.add_argument('--content-image', type=str, default='input/IMG_4343.jpg',
                        help='the content image')
    parser.add_argument('--style-image', type=str, default='input/starry_night.jpg',
                        help='the style image')
    parser.add_argument('--stop-eps', type=float, default=.005,
                        help='stop if the relative chanage is less than eps')
    parser.add_argument('--content-weight', type=float, default=10,
                        help='the weight for the content image')
    parser.add_argument('--style-weight', type=float, default=1,
                        help='the weight for the style image')
    parser.add_argument('--tv-weight', type=float, default=1e-2,
                        help='the magtitute on TV loss')
    parser.add_argument('--max-num-epochs', type=int, default=1000,
                        help='the maximal number of training epochs')
    parser.add_argument('--max-long-edge', type=int, default=600,
                        help='resize the content image')
    parser.add_argument('--lr', type=float, default=.001,
                        help='the initial learning rate')
    parser.add_argument('--gpu', type=int, default=0,
                        help='which gpu card to use, -1 means using cpu')
    parser.add_argument('--output_dir', type=str, default='output/',
                        help='the output image')
    parser.add_argument('--save-epochs', type=int, default=50,
                        help='save the output every n epochs')
    parser.add_argument('--remove-noise', type=float, default=.02,
                        help='the magtitute to remove noise')
    parser.add_argument('--lr-sched-delay', type=int, default=75,
                        help='how many epochs between decreasing learning rate')
    # The factor is fractional (default 0.9), so it has to be parsed as a
    # float; with type=int a value such as "0.5" could not be supplied.
    parser.add_argument('--lr-sched-factor', type=float, default=0.9,
                        help='factor to decrease learning rate on schedule')

    # parse_args(None) already falls back to sys.argv, so no branch is needed.
    return parser.parse_args(arglist)


def PreprocessContentImage(path, long_edge):
    """Load the content image and convert it into a 4-D input batch.

    The image is rescaled so that its longer side becomes `long_edge`,
    multiplied by 256, reordered from (rows, cols, channels) to
    (channels, rows, cols), and a fixed offset is subtracted from each of
    the first three channels.

    Args:
        path: image location, handed to ``io.imread``.
        long_edge: desired length of the longer image side.

    Returns:
        A numpy array of shape (1, 3, rows, cols).
    """
    raw = io.imread(path)
    logging.info("load the content image, size = %s", raw.shape[:2])
    rows, cols = raw.shape[0], raw.shape[1]
    scale = float(long_edge) / max(raw.shape[:2])
    new_size = (int(rows * scale), int(cols * scale))
    scaled = np.asarray(transform.resize(raw, new_size)) * 256
    # Two axis swaps take (rows, cols, channels) to (channels, rows, cols).
    planes = scaled.swapaxes(0, 2).swapaxes(1, 2)
    # Subtract the per-channel offset from each of the three colour planes.
    for channel, offset in enumerate((123.68, 116.779, 103.939)):
        planes[channel, :] -= offset
    logging.info("resize the content image to %s", new_size)
    batch_shape = (1, 3, planes.shape[1], planes.shape[2])
    return np.resize(planes, batch_shape)

def PreprocessStyleImage(path, shape):
    """Load the style image and convert it into a 4-D input batch.

    The image is resized to (shape[2], shape[3]), multiplied by 256,
    reordered from (rows, cols, channels) to (channels, rows, cols), and a
    fixed offset is subtracted from each of the first three channels.

    Args:
        path: image location, handed to ``io.imread``.
        shape: 4-element shape whose last two entries give the target
            rows and cols.

    Returns:
        A numpy array of shape (1, 3, rows, cols).
    """
    raw = io.imread(path)
    target = (shape[2], shape[3])
    scaled = np.asarray(transform.resize(raw, target)) * 256
    # Two axis swaps take (rows, cols, channels) to (channels, rows, cols).
    planes = scaled.swapaxes(0, 2).swapaxes(1, 2)

    # Subtract the per-channel offset from each of the three colour planes.
    for channel, offset in enumerate((123.68, 116.779, 103.939)):
        planes[channel, :] -= offset
    batch_shape = (1, 3, planes.shape[1], planes.shape[2])
    return np.resize(planes, batch_shape)

def PostprocessImage(img):
    """Turn a 4-D image batch back into a displayable uint8 picture.

    Adds the fixed per-channel offsets back, reorders the data from
    (channels, rows, cols) to (rows, cols, channels), clips the values to
    [0, 255] and casts to uint8. The input array is left untouched because
    ``np.resize`` returns a new array.

    Args:
        img: array whose axes 2 and 3 are the rows and cols of the picture.

    Returns:
        A uint8 numpy array of shape (rows, cols, 3).
    """
    planes = np.resize(img, (3, img.shape[2], img.shape[3]))
    for channel, offset in enumerate((123.68, 116.779, 103.939)):
        planes[channel, :] += offset
    # Two axis swaps take (channels, rows, cols) to (rows, cols, channels).
    pixels = planes.swapaxes(1, 2).swapaxes(0, 2)
    return np.clip(pixels, 0, 255).astype('uint8')

def SaveImage(img, filename, remove_noise=0.):
    """Post-process an image batch and write it to `filename`.

    Args:
        img: image batch, converted with PostprocessImage before saving.
        filename: destination path given to ``io.imsave``.
        remove_noise: weight passed to ``denoise_tv_chambolle``; a value of
            0 skips the denoising step entirely.
    """
    logging.info('save output to %s', filename)
    picture = PostprocessImage(img)
    wants_denoise = remove_noise != 0.0
    if wants_denoise:
        picture = denoise_tv_chambolle(
            picture, weight=remove_noise, multichannel=True)
    io.imsave(filename, picture)

def style_gram_symbol(input_size, style):
    """Build one gram-matrix symbol per output of `style`.

    Args:
        input_size: pair whose two entries are used as the last two
            dimensions of the (1, 3, ., .) data shape for shape inference.
        style: symbol group whose outputs are the style feature maps.

    Returns:
        A tuple ``(grams, grad_scale)``: the grouped gram symbols and, for
        each output, the value ``prod(shape[1:]) * shape[1]`` computed from
        that output's inferred shape.
    """
    data_shape = (1, 3, input_size[0], input_size[1])
    _, output_shapes, _ = style.infer_shape(data=data_shape)
    gram_list, grad_scale = [], []
    for idx in range(len(style.list_outputs())):
        out_shape = output_shapes[idx]
        channels = out_shape[1]
        # Flatten everything after the channel axis into one dimension.
        flat = mx.sym.Reshape(
            style[idx],
            target_shape=(int(channels), int(np.prod(out_shape[2:]))))
        # use fully connected to quickly do dot(x, x^T)
        gram_list.append(
            mx.sym.FullyConnected(flat, flat, no_bias=True, num_hidden=channels))
        grad_scale.append(np.prod(out_shape[1:]) * channels)
    return mx.sym.Group(gram_list), grad_scale


def get_loss(gram, content):
    """Build the sum-of-squared-difference loss symbols.

    For every output of `gram` a variable named ``target_gram_<i>`` is
    created and compared against that output; `content` is compared
    against a variable named ``target_content``.

    Args:
        gram: symbol group of gram matrices.
        content: content feature symbol.

    Returns:
        A tuple ``(gram_losses, content_loss)`` where the first item is the
        group of per-output gram losses.
    """
    gram_loss = []
    for idx in range(len(gram.list_outputs())):
        target = mx.sym.Variable("target_gram_%d" % idx)
        difference = target - gram[idx]
        gram_loss.append(mx.sym.sum(mx.sym.square(difference)))
    content_target = mx.sym.Variable("target_content")
    content_difference = content_target - content
    content_loss = mx.sym.sum(mx.sym.square(content_difference))
    return mx.sym.Group(gram_loss), content_loss

def get_tv_grad_executor(img, ctx, tv_weight):
    """Create the executor that produces the TV gradient for `img`.

    Each channel of `img` is convolved with the same 3x3 kernel and the
    results are concatenated and multiplied by `tv_weight`. The executor
    is bound directly on `img`.

    Args:
        img: image array; its second dimension is the channel count.
        ctx: context used for the kernel array and for binding.
        tv_weight: multiplier applied to the convolution output.

    Returns:
        The bound executor, or None when `tv_weight` is not positive.
    """
    if tv_weight <= 0.0:
        return None
    nchannel = img.shape[1]
    img_sym = mx.sym.Variable("img")
    kernel_sym = mx.sym.Variable("kernel")
    per_channel = mx.sym.SliceChannel(img_sym, num_outputs=nchannel)
    # One single-filter convolution per channel, all sharing the kernel.
    filtered = []
    for idx in range(nchannel):
        filtered.append(
            mx.sym.Convolution(data=per_channel[idx], weight=kernel_sym,
                               num_filter=1,
                               kernel=(3, 3), pad=(1, 1),
                               no_bias=True, stride=(1, 1)))
    out = mx.sym.Concat(*filtered)
    stencil = np.array([[0, -1, 0],
                        [-1, 4, -1],
                        [0, -1, 0]]).reshape((1, 1, 3, 3))
    kernel = mx.nd.array(stencil, ctx) / 8.0
    out = out * tv_weight
    bound_args = {"img": img, "kernel": kernel}
    return out.bind(ctx, args=bound_args)

def train_nstyle(args, callback=None):
    """Train a neural style network.

    Args:
        args: namespace from get_args(); controls input, output and
            hyper-parameters.
        callback: optional callable for displaying training progress. It is
            called once per epoch that does not trigger the stop condition,
            with a dict holding 'eps' and 'epoch'; on epochs where an
            intermediate image is written the dict also holds 'filename'
            and 'img'.

    The image is written to ``<output_dir>/e_<epoch>.jpg`` every
    ``args.save_epochs`` epochs (never, if that value is not positive) and
    to ``<output_dir>/final.jpg`` when training ends.
    """
    # Local import keeps the path handling self-contained in this function.
    import os

    # input
    dev = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
    content_np = PreprocessContentImage(args.content_image, args.max_long_edge)
    style_np = PreprocessStyleImage(args.style_image, shape=content_np.shape)
    size = content_np.shape[2:]

    # model
    model_module = importlib.import_module('model_' + args.model)
    style, content = model_module.get_symbol()
    gram, gscale = style_gram_symbol(size, style)
    model_executor = model_module.get_executor(gram, content, size, dev)

    # Forward the style image once and keep its gram outputs as targets.
    model_executor.data[:] = style_np
    model_executor.executor.forward()
    style_array = [model_executor.style[i].copyto(mx.cpu())
                   for i in range(len(model_executor.style))]

    # Forward the content image once and keep its content output as target.
    model_executor.data[:] = content_np
    model_executor.executor.forward()
    content_array = model_executor.content.copyto(mx.cpu())

    # delete the executor
    del model_executor

    style_loss, content_loss = get_loss(gram, content)
    model_executor = model_module.get_executor(
        style_loss, content_loss, size, dev)

    # One head gradient per loss output: the style losses are scaled by
    # style_weight / gscale[i], the content loss by content_weight.
    grad_array = []
    for i in range(len(style_array)):
        style_array[i].copyto(model_executor.arg_dict["target_gram_%d" % i])
        grad_array.append(mx.nd.ones((1,), dev) * (float(args.style_weight) / gscale[i]))
    grad_array.append(mx.nd.ones((1,), dev) * (float(args.content_weight)))

    print([x.asscalar() for x in grad_array])
    content_array.copyto(model_executor.arg_dict["target_content"])

    # train
    # initialize img with random noise
    img = mx.nd.zeros(content_np.shape, ctx=dev)
    img[:] = mx.rnd.uniform(-0.1, 0.1, img.shape)

    lr = mx.lr_scheduler.FactorScheduler(step=args.lr_sched_delay,
                                         factor=args.lr_sched_factor)

    optimizer = mx.optimizer.NAG(
        learning_rate=args.lr,
        wd=0.0001,
        momentum=0.95,
        lr_scheduler=lr)
    optim_state = optimizer.create_state(0, img)

    logging.info('start training arguments %s', args)
    old_img = img.copyto(dev)
    clip_norm = 1 * np.prod(img.shape)
    tv_grad_executor = get_tv_grad_executor(img, dev, args.tv_weight)

    for e in range(args.max_num_epochs):
        img.copyto(model_executor.data)
        model_executor.executor.forward()
        model_executor.executor.backward(grad_array)

        # Rescale the gradient when its norm exceeds the clipping threshold.
        gnorm = mx.nd.norm(model_executor.data_grad).asscalar()
        if gnorm > clip_norm:
            model_executor.data_grad[:] *= clip_norm / gnorm

        grad = model_executor.data_grad
        if tv_grad_executor is not None:
            tv_grad_executor.forward()
            grad = grad + tv_grad_executor.outputs[0]
        # The optimizer updates img in place.
        optimizer.update(0, img, grad, optim_state)

        # Relative change of the image produced by this update.
        eps = (mx.nd.norm(old_img - img) / mx.nd.norm(img)).asscalar()

        old_img = img.copyto(dev)
        logging.info('epoch %d, relative change %f', e, eps)
        if eps < args.stop_eps:
            logging.info('eps < args.stop_eps, training finished')
            break

        if callback:
            cbdata = {
                'eps': eps,
                'epoch': e+1,
            }
        # A non-positive save interval disables intermediate saves instead
        # of failing on the modulo.
        if args.save_epochs > 0 and (e+1) % args.save_epochs == 0:
            # os.path.join works whether or not output_dir ends with a
            # separator; plain concatenation needed the trailing slash.
            outfn = os.path.join(args.output_dir, 'e_'+str(e+1)+'.jpg')
            npimg = img.asnumpy()
            SaveImage(npimg, outfn, args.remove_noise)
            if callback:
                cbdata['filename'] = outfn
                cbdata['img'] = npimg
        if callback:
            callback(cbdata)

    # img is always bound, so this also works when no epoch was run. The
    # final image gets the same denoising as the intermediate saves.
    final_fn = os.path.join(args.output_dir, 'final.jpg')
    SaveImage(img.asnumpy(), final_fn, args.remove_noise)


if __name__ == "__main__":
    # Script entry point: parse the options from the command line and run
    # the training with them.
    args = get_args()
    train_nstyle(args)

