import numpy as np
from scipy import linalg as LA
import mxnet as mx
import argparse
import utils

def conv_vh_decomposition(model, args):
  """Replace one convolution layer by a vertical + horizontal convolution pair.

  The (N, C, y, x) kernel of ``args.layer`` is unfolded into a (C*y, N*x)
  matrix and factorized with a truncated SVD that keeps ``args.K`` components.
  The left factor becomes a (K, C, y, 1) convolution named ``<layer>_v`` and
  the right factor a (N, K, 1, x) convolution named ``<layer>_h``.  The first
  convolution gets a zero bias, the second one inherits the original bias.

  Args:
    model: object whose ``arg_params`` maps parameter names to arrays that
      provide ``asnumpy()``; it is also passed to ``utils.replace_conv_layer``.
    args: namespace providing ``layer`` (name of the convolution to replace)
      and ``K`` (number of singular components to keep).

  Returns:
    The value returned by ``utils.replace_conv_layer`` for the rewritten model.

  Raises:
    ValueError: if ``args.K`` is missing or outside ``[1, min(C*y, N*x)]``.
    KeyError: propagated from the ``arg_params`` lookup when the layer's
      ``_weight`` or ``_bias`` entry does not exist.
  """
  import ast  # local import: only used to parse the layer attribute strings

  W = model.arg_params[args.layer + '_weight'].asnumpy()
  N, C, y, x = W.shape
  b = model.arg_params[args.layer + '_bias'].asnumpy()

  # Reject an unusable rank up front; otherwise the failure only shows up later
  # as an unrelated reshape error.
  K = args.K
  max_rank = min(C * y, N * x)
  if K is None or not 1 <= K <= max_rank:
    raise ValueError('K must be an integer in [1, %d] for layer %r, got %r'
                     % (max_rank, args.layer, K))

  # Unfold to 2-D: rows run over (input channel, kernel row), columns over
  # (output filter, kernel column).
  W = W.transpose((1, 2, 0, 3)).reshape((C * y, -1))

  U, D, Q = np.linalg.svd(W, full_matrices=False)
  # Singular values are non-negative, so sqrt(diag(D)) is simply the
  # element-wise square root; broadcasting scales column k by sqrt(D[k]).
  sqrt_D = np.sqrt(D[:K])
  V = U[:, :K] * sqrt_D      # (C*y, K)
  H = Q[:K].T * sqrt_D       # (N*x, K)

  W1 = V.T.reshape(K, C, y, 1)
  b1 = np.zeros((K, ))
  W2 = H.reshape(N, x, 1, K).transpose((0, 3, 2, 1))   # -> (N, K, 1, x)
  b2 = b

  def sym_handle(data, node):
    # Build the two-convolution symbol that stands in for `node`.
    params = node['param']

    def shape_attr(key, default):
      # The attributes are stored as strings such as "(3, 3)".  literal_eval
      # parses plain literals only, unlike eval, which would execute whatever
      # expression a model description happens to contain.
      if key in params:
        return ast.literal_eval(params[key])
      return default

    kernel = ast.literal_eval(params['kernel'])
    # Defaults follow the Convolution operator: no padding, unit stride.
    pad = shape_attr('pad', (0, 0))
    stride = shape_attr('stride', (1, 1))
    name = node['name']

    # Each half keeps the original stride/padding along the axis it filters
    # and is neutral along the other one, so the output size is unchanged.
    sym1 = mx.symbol.Convolution(
        data=data, kernel=(kernel[0], 1), stride=(stride[0], 1),
        pad=(pad[0], 0), num_filter=W1.shape[0], name=name + '_v')
    sym2 = mx.symbol.Convolution(
        data=sym1, kernel=(1, kernel[1]), stride=(1, stride[1]),
        pad=(0, pad[1]), num_filter=W2.shape[0], name=name + '_h')
    return sym2

  def arg_handle(arg_shape_dic, arg_params):
    # Install the factorized parameters after checking them against the
    # shapes expected for the new layers.
    name1 = args.layer + '_v'
    name2 = args.layer + '_h'
    new_params = [
        (name1 + '_weight', mx.ndarray.array(W1), 'weight1'),
        (name2 + '_weight', mx.ndarray.array(W2), 'weight2'),
        (name1 + '_bias', mx.ndarray.array(b1), 'bias1'),
        (name2 + '_bias', mx.ndarray.array(b2), 'bias2'),
    ]
    # Verify every shape before touching arg_params so a mismatch leaves the
    # parameter dictionary unmodified.
    for key, value, label in new_params:
      assert value.shape == arg_shape_dic[key], label
    for key, value, _ in new_params:
      arg_params[key] = value

  new_model = utils.replace_conv_layer(args.layer, model, sym_handle, arg_handle)
  return new_model

def main():
  """Load the model, decompose the selected layer and save the result.

  Every setting is read from the module-level ``args`` namespace.
  """
  source_model = utils.load_model(args)
  conv_vh_decomposition(source_model, args).save(args.save_model)

if __name__ == '__main__':
  # Script entry point: parse the command line into the module-level `args`
  # namespace that main() reads, then run the decomposition.
  parser = argparse.ArgumentParser()
  parser.add_argument('-m', '--model', help='the model to speed up')
  parser.add_argument('-g', '--gpus', default='0', help='the gpus to be used in ctx')
  parser.add_argument('--load-epoch', type=int, default=1)
  # --layer is concatenated into parameter names and --K is used as a slice
  # bound and reshape size, so omitting either one can only end in an
  # unrelated TypeError later on; let argparse report it as a usage error.
  parser.add_argument('--layer', required=True,
                      help='name of the convolution layer to decompose')
  parser.add_argument('--K', type=int, required=True,
                      help='number of singular components to keep')
  parser.add_argument('--save-model',
                      help='destination passed to save() for the new model')
  args = parser.parse_args()
  main()
  
