def batchify_reshape_dot_reshape()

in python/hlo/optimize.py [0:0]


    def batchify_reshape_dot_reshape(self):
        """Rewrite ``reshape -> dot -> reshape`` chains so the dot keeps the leading dims.

        A chain is a candidate when all of the following hold:

        * the trailing op is a ``reshape`` that is not a module output and whose
          operand is a ``dot`` that is not a module output and has exactly one
          consumer;
        * the dot's lhs is a 2-D ``reshape`` (not a module output) contracted on
          dimension 1, and its rhs is 2-D and contracted on dimension 0;
        * the product of the lhs reshape's leading dims equals the product of
          the trailing reshape's leading dims;
        * the op feeding the lhs reshape and the trailing reshape agree on
          their first dimension.

        Candidates are grouped by their lhs reshape. A group is rewritten only
        if every consumer of that lhs reshape is a matched dot and all trailing
        reshapes in the group share the same leading dims. The rewrite mutates
        the instruction protos in place: the lhs reshape takes the trailing
        reshape's leading dims (its last dim is kept), each dot takes the full
        shape of its trailing reshape, and the dot's lhs contracting dimension
        becomes the last one. Nothing is returned.
        """
        # rewrite (batch) -> reshape -> dot -> reshape -> (batch) with dot to enable batch analyzer
        ops = [HloOp(inst) for inst in self.entry_instructions]

        # Index the ops by id and give each one an empty consumer list.
        op_by_id = {}
        for op in ops:
            op_by_id[op.id] = op
            op.consumer_ids = []

        # Record, for every operand, the distinct ids of the ops that use it.
        for user in ops:
            for operand_id in user.operand_ids:
                users_of_operand = op_by_id[operand_id].consumer_ids
                if user.id not in users_of_operand:
                    users_of_operand.append(user.id)

        # Collect [dot, trailing reshape] pairs keyed by the id of the lhs reshape.
        matches_by_lhs_reshape = {}
        for out_reshape in ops:
            if out_reshape.opcode != 'reshape' or out_reshape.id in self.output_ids:
                continue
            dot = op_by_id[out_reshape.operand_ids[0]]
            if dot.opcode != 'dot' or dot.id in self.output_ids or len(dot.consumer_ids) != 1:
                continue
            lhs_id, rhs_id = dot.operand_ids
            lhs_reshape = op_by_id[lhs_id]
            rhs = op_by_id[rhs_id]
            dot_dims = dot.inst.dot_dimension_numbers
            lhs_is_plain_matmul = len(lhs_reshape.shape) == 2 and dot_dims.lhs_contracting_dimensions == [1]
            if not lhs_is_plain_matmul:
                continue
            rhs_is_plain_matmul = len(rhs.shape) == 2 and dot_dims.rhs_contracting_dimensions == [0]
            if not rhs_is_plain_matmul:
                continue
            if lhs_reshape.opcode != 'reshape' or lhs_reshape.id in self.output_ids:
                continue
            if np.prod(lhs_reshape.shape[:-1]) != np.prod(out_reshape.shape[:-1]):
                continue
            source = op_by_id[lhs_reshape.operand_ids[0]]
            if source.shape[0] != out_reshape.shape[0]:  # need same input/output batch sizes
                continue
            matches_by_lhs_reshape.setdefault(lhs_reshape.id, []).append([dot, out_reshape])

        for lhs_reshape_id, matches in matches_by_lhs_reshape.items():
            lhs_reshape = op_by_id[lhs_reshape_id]
            # Every consumer of the lhs reshape must be one of the matched dots.
            if not matches or len(lhs_reshape.consumer_ids) != len(matches):
                continue
            first_out_reshape = matches[0][1]
            leading_dims = first_out_reshape.shape[:-1]
            if any(out.shape[:-1] != leading_dims for _, out in matches):
                continue

            # Give the lhs reshape the trailing reshape's leading dims; its last dim stays.
            lhs_shape = lhs_reshape.inst.shape
            lhs_shape.dimensions[:-1] = leading_dims
            # NOTE(review): the rank is read only after the dimensions were overwritten,
            # exactly as before; if `shape` mirrors the proto this is the new rank -- confirm.
            lhs_shape.is_dynamic_dimension[:] = [False] * len(lhs_reshape.shape)
            lhs_shape.layout.minor_to_major[:] = reversed(range(len(lhs_reshape.shape)))

            for dot, out_reshape in matches:
                out_rank = len(out_reshape.shape)
                dot_shape = dot.inst.shape
                # The dot now produces the trailing reshape's shape directly.
                dot_shape.dimensions[:] = out_reshape.shape
                dot_shape.is_dynamic_dimension[:] = [False] * out_rank
                dot_shape.layout.minor_to_major[:] = reversed(range(out_rank))
                # The lhs contracts on its last dimension.
                dot.inst.dot_dimension_numbers.lhs_contracting_dimensions[:] = [out_rank - 1]