def maybe_enable_rtr_shuffle()

in python/hlo/optimize.py [0:0]


    def maybe_enable_rtr_shuffle(self):
        """Rewrite eligible entry parameters and record their input shuffles.

        For every entry instruction whose opcode is ``'parameter'``, gather
        the instructions that read it as their first operand and ask
        ``get_rtr_rewriter`` for a rewriter per consumer.  A parameter is
        rewritten only when all of the following hold:

        * it has at least one such consumer,
        * every consumer produced a rewriter (none is ``None``),
        * ``_all_arrays_equal`` accepts the rewriters' ``shuffle_indices``,
        * for each kernel id, ``_all_arrays_equal`` accepts the
          ``kernel_array_prtr`` of the rewriters grouped under it.

        When a parameter qualifies, the first rewriter rewrites the parameter
        and the first rewriter of each kernel group rewrites that kernel.

        Side effects: ``self.input_shuffles`` is assigned (one entry per host
        program parameter name, ``None`` for parameters left untouched) only
        if at least one parameter was rewritten.
        ``self._reestablish_program_shapes()`` is always called at the end.
        """
        id_to_inst = {entry.id: entry for entry in self.entry_instructions}
        parameter_insts = [
            entry for entry in self.entry_instructions if entry.opcode == 'parameter'
        ]

        # Parameter id -> ids of the instructions using it as operand 0
        # (insertion-ordered, without duplicates).
        consumer_ids_by_param = {param.id: [] for param in parameter_insts}
        for entry in self.entry_instructions:
            if not entry.operand_ids:
                continue
            first_operand_id = entry.operand_ids[0]
            if id_to_inst[first_operand_id].opcode != 'parameter':
                continue
            known_ids = consumer_ids_by_param[first_operand_id]
            if entry.id not in known_ids:
                known_ids.append(entry.id)

        shuffle_by_param_id = {}
        for param in parameter_insts:
            consumers = [id_to_inst[cid] for cid in consumer_ids_by_param[param.id]]
            rewriters = [
                get_rtr_rewriter(consumer, id_to_inst, self.output_ids)
                for consumer in consumers
            ]

            # Skip parameters with no consumers or with any consumer that
            # could not be given a rewriter.
            if not rewriters or any(rewriter is None for rewriter in rewriters):
                continue
            # All consumers must request the same shuffle of this input.
            if not _all_arrays_equal(rewriter.shuffle_indices for rewriter in rewriters):
                continue

            # Group the distinct rewriters by the kernel they act on; keys are
            # the consumers' second operand ids.
            rewriters_by_kernel_id = {consumer.operand_ids[1]: [] for consumer in consumers}
            for rewriter in rewriters:
                group = rewriters_by_kernel_id[rewriter.kernel_inst.id]
                if rewriter not in group:
                    group.append(rewriter)

            # A kernel shared by several rewriters can only be rewritten once,
            # so every rewriter in a group has to agree; stop at the first
            # group that does not.
            kernels_consistent = all(
                _all_arrays_equal(rewriter.kernel_array_prtr for rewriter in group)
                for group in rewriters_by_kernel_id.values()
            )
            if not kernels_consistent:
                continue

            # Rewrite the input first, then each kernel once.
            shuffle_indices = rewriters[0].rewrite_input(param)
            for kernel_id, group in rewriters_by_kernel_id.items():
                if group:
                    group[0].rewrite_kernel(id_to_inst[kernel_id])
            shuffle_by_param_id[param.id] = shuffle_indices

        # One slot per host program parameter, in declaration order.
        input_shuffles = [
            shuffle_by_param_id.get(self.parameter_name_to_id[name], None)
            for name in self.hlo_module.host_program_shape.parameter_names
        ]
        if not all(shuffle is None for shuffle in input_shuffles):
            self.input_shuffles = input_shuffles

        # Change the host program shape and the entry computation program
        # shape as well.
        self._reestablish_program_shapes()