in src/model.py [0:0]
def split_states(x, n):
    """Split the trailing axis of ``x`` into two axes ``[n, m // n]``.

    Args:
        x: Tensor whose last dimension (size ``m``) is split. ``m`` is
            presumably expected to be divisible by ``n``. This function does
            not check that; ``tf.reshape`` is left to reject an incompatible
            target shape.
        n: Size of the new second-to-last axis.

    Returns:
        ``x`` reshaped via ``tf.reshape``. All leading dimensions (as
        reported by ``shape_list``) are kept unchanged. The final dimension
        ``m`` is replaced by the pair ``n, m // n``.
    """
    # shape_list is defined elsewhere in this file. It may mix static ints
    # and dynamic tensors; ``//`` works for both.
    *leading_dims, last_dim = shape_list(x)
    chunk = last_dim // n
    target_shape = [*leading_dims, n, chunk]
    return tf.reshape(x, target_shape)