in src/model.py [0:0]
def merge_states(x):
    """Collapse the two trailing axes of ``x`` into a single axis.

    The sizes of every axis except the last two are kept unchanged. The final
    two sizes are multiplied together to form the new last axis.

    Args:
        x: A tensor whose per-axis sizes are obtained via ``shape_list(x)``.
            It must have at least two axes. (Assumes ``shape_list`` may return
            either Python ints or scalar tensors per axis. TODO confirm.)

    Returns:
        ``tf.reshape`` of ``x`` with shape ``[*leading, rows * cols]``.

    Raises:
        ValueError: If ``shape_list(x)`` yields fewer than two sizes.
    """
    dims = list(shape_list(x))
    leading = dims[:-2]
    rows, cols = dims[-2:]  # unpacking fails with ValueError when rank < 2
    merged = rows * cols
    return tf.reshape(x, leading + [merged])