in python/hlo/optimize.py [0:0]
def batchify_reshape_dot_reshape(self):
    """Rewrite (batch) -> reshape -> dot -> reshape -> (batch) chains in place.

    The intent (per the original author's note) is to let the dot itself
    carry the batch dimensions so that the batch analyzer can handle it.

    The pass works on ``self.entry_instructions`` in two phases:

    1. Find every ``reshape`` whose only operand is a ``dot`` that
       * is not a graph output and has exactly one consumer,
       * has a rank-2 lhs contracted on dimension 1 and a rank-2 rhs
         contracted on dimension 0,
       * has an lhs that is itself a non-output ``reshape``,
       and where the leading dimensions of the input and output reshapes
       have the same product and the same first dimension.  Matches are
       grouped by the id of the input reshape.
    2. For each input reshape whose consumers are all matched dots with
       identical output leading dimensions, overwrite the shape protos of
       the input reshape and of every dot (dimensions, dynamic flags,
       layout) and point the dot's lhs contracting dimension at the last
       axis.

    Instructions are mutated in place; nothing is returned.
    """
    ops = [HloOp(inst) for inst in self.entry_instructions]
    op_by_id = {op.id: op for op in ops}

    # Record, for every op, the ids of the ops that consume it (no duplicates).
    for op in ops:
        op.consumer_ids = []
    for consumer in ops:
        for operand_id in consumer.operand_ids:
            producer = op_by_id[operand_id]
            if consumer.id not in producer.consumer_ids:
                producer.consumer_ids.append(consumer.id)

    # Phase 1: collect [dot, output reshape] pairs keyed by input reshape id.
    chains_by_input_reshape = {}
    for out_reshape in ops:
        if out_reshape.opcode != 'reshape' or out_reshape.id in self.output_ids:
            continue

        dot = op_by_id[out_reshape.operand_ids[0]]
        if dot.opcode != 'dot' or dot.id in self.output_ids:
            continue
        # The dot must feed only this reshape.
        if len(dot.consumer_ids) != 1:
            continue

        lhs_id, rhs_id = dot.operand_ids
        in_reshape = op_by_id[lhs_id]
        rhs = op_by_id[rhs_id]
        dot_dims = dot.inst.dot_dimension_numbers

        lhs_is_plain_matrix = (
            len(in_reshape.shape) == 2
            and dot_dims.lhs_contracting_dimensions == [1]
        )
        if not lhs_is_plain_matrix:
            continue
        rhs_is_plain_matrix = (
            len(rhs.shape) == 2
            and dot_dims.rhs_contracting_dimensions == [0]
        )
        if not rhs_is_plain_matrix:
            continue

        if in_reshape.opcode != 'reshape' or in_reshape.id in self.output_ids:
            continue
        # Leading (non-contracted) element counts must agree on both sides.
        if np.prod(in_reshape.shape[:-1]) != np.prod(out_reshape.shape[:-1]):
            continue

        source = op_by_id[in_reshape.operand_ids[0]]
        # need same input/output batch sizes
        if source.shape[0] != out_reshape.shape[0]:
            continue

        chains_by_input_reshape.setdefault(in_reshape.id, []).append(
            [dot, out_reshape]
        )

    # Phase 2: rewrite shapes for groups in which every consumer matched.
    for in_reshape_id, chains in chains_by_input_reshape.items():
        in_reshape = op_by_id[in_reshape_id]
        if not chains or len(in_reshape.consumer_ids) != len(chains):
            continue

        # All output reshapes must share the leading dims of the first one.
        leading_dims = chains[0][1].shape[:-1]
        if any(out.shape[:-1] != leading_dims for _, out in chains):
            continue

        # Replace every dimension of the input reshape except the last one.
        in_shape = in_reshape.inst.shape
        in_shape.dimensions[:-1] = leading_dims
        # NOTE(review): the rank below is read from in_reshape.shape *after*
        # dimensions was rewritten; whether it reflects the new rank depends
        # on how HloOp exposes `shape` -- confirm against HloOp.
        in_shape.is_dynamic_dimension[:] = [False] * len(in_reshape.shape)
        in_shape.layout.minor_to_major[:] = reversed(
            range(len(in_reshape.shape))
        )

        for dot, out_reshape in chains:
            out_dims = out_reshape.shape
            out_rank = len(out_dims)
            dot_shape = dot.inst.shape
            # The dot now produces the output reshape's shape directly.
            dot_shape.dimensions[:] = out_dims
            dot_shape.is_dynamic_dimension[:] = [False] * out_rank
            dot_shape.layout.minor_to_major[:] = reversed(range(out_rank))
            # Contract the lhs on its last axis.
            dot.inst.dot_dimension_numbers.lhs_contracting_dimensions[:] = [
                out_rank - 1
            ]