in python/hlo/optimize.py [0:0]
def maybe_enable_rtr_shuffle(self):
    """Rewrite entry parameters (and their kernels) when every consumer allows it.

    For each ``'parameter'`` instruction in ``self.entry_instructions`` this
    collects the instructions whose *first* operand is that parameter and asks
    ``get_rtr_rewriter`` for a rewriter per consumer.  The parameter is only
    rewritten when:

    * it has at least one such consumer and none of the rewriters is ``None``;
    * ``_all_arrays_equal`` accepts the ``shuffle_indices`` of all rewriters;
    * for every kernel instruction, ``_all_arrays_equal`` accepts the
      ``kernel_array_prtr`` of the rewriters grouped under that kernel.

    When at least one parameter was rewritten, ``self.input_shuffles`` is set
    to a list ordered like ``hlo_module.host_program_shape.parameter_names``
    (``None`` for parameters left untouched) and
    ``self._reestablish_program_shapes()`` is called.  Otherwise ``self`` is
    not modified.  Always returns ``None``.
    """
    entry_insts = self.entry_instructions
    id_to_inst = {inst.id: inst for inst in entry_insts}
    parameter_insts = [inst for inst in entry_insts if inst.opcode == 'parameter']

    # Map each parameter id to the (de-duplicated, order-preserving) ids of
    # the instructions that take it as operand 0.
    consumer_ids_by_param = {param.id: [] for param in parameter_insts}
    for inst in entry_insts:
        if not inst.operand_ids:
            continue
        first_operand_id = inst.operand_ids[0]
        if id_to_inst[first_operand_id].opcode != 'parameter':
            continue
        known_ids = consumer_ids_by_param[first_operand_id]
        if inst.id not in known_ids:
            known_ids.append(inst.id)

    shuffle_by_param_id = {}
    for param in parameter_insts:
        consumers = [id_to_inst[cid] for cid in consumer_ids_by_param[param.id]]
        rewriters = [
            get_rtr_rewriter(consumer, id_to_inst, self.output_ids)
            for consumer in consumers
        ]

        # Every consumer must be rewritable, and there must be at least one.
        if not rewriters or any(rewriter is None for rewriter in rewriters):
            continue
        # All consumers have to agree on how the input gets shuffled.
        if not _all_arrays_equal(rewriter.shuffle_indices for rewriter in rewriters):
            continue

        # Group the distinct rewriters by the kernel instruction they target.
        # NOTE(review): keys come from consumer.operand_ids[1] while lookups
        # use rewriter.kernel_inst.id -- presumably the same id; confirm
        # against get_rtr_rewriter.
        rewriters_by_kernel_id = {consumer.operand_ids[1]: [] for consumer in consumers}
        for rewriter in rewriters:
            group = rewriters_by_kernel_id[rewriter.kernel_inst.id]
            if rewriter not in group:
                group.append(rewriter)

        # A kernel shared by several consumers must be rewritten identically
        # by each of them; stop checking at the first disagreement.
        kernels_disagree = any(
            not _all_arrays_equal(rewriter.kernel_array_prtr for rewriter in group)
            for group in rewriters_by_kernel_id.values()
        )
        if kernels_disagree:
            continue

        # All checks passed: rewrite the parameter once, then each kernel once.
        shuffle_indices = rewriters[0].rewrite_input(param)
        for kernel_id, group in rewriters_by_kernel_id.items():
            if group:
                group[0].rewrite_kernel(id_to_inst[kernel_id])
        shuffle_by_param_id[param.id] = shuffle_indices

    input_shuffles = [
        shuffle_by_param_id.get(self.parameter_name_to_id[name])
        for name in self.hlo_module.host_program_shape.parameter_names
    ]
    if all(shuffle is None for shuffle in input_shuffles):
        return
    self.input_shuffles = input_shuffles
    # The host program shape and the entry computation program shape have to
    # be updated as well.
    self._reestablish_program_shapes()