in python/hlo/optimize.py [0:0]
def maybe_enable_dynamic_batch_size(self):
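    # wrap every entry instruction in an HloOp and initialize the per-op metadata
    # used below to track batch-axis propagation through the graph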
    hlo_op_list = [HloOp(inst) for inst in self.entry_instructions]
    id_to_op = {op.id: op for op in hlo_op_list}
    for op in hlo_op_list:
        op.input_shapes = [id_to_op[oid].shape for oid in op.operand_ids]
        op.consumer_ids = []
        op.batch_propagable_ids = None
        op.batch_propagable_neighbor_ids = []
        op.batch_axis = None
        op.batch_size_multiplier = 1
        op.is_batch_axis_seed = False
        op.batch_axis_source_ids = []
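    # record each operand's consumers so the graph can be walked from producers to consumers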
    for op in hlo_op_list:
        for oid in op.operand_ids:
            oop = id_to_op[oid]
            if op.id not in oop.consumer_ids:
                oop.consumer_ids.append(op.id)
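    # ask BatchHloInstructionPool which ops can carry a batch axis; ops that also
    # define a concrete batch axis become seeds for the propagation below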
    equivalent_source_ids = {}
    for op in hlo_op_list:
        func = getattr(BatchHloInstructionPool, op.legal_opcode, None)
        if func is not None:
            batch_definition = func(op)
            if batch_definition is not None:
                op.batch_propagable_ids, op.batch_axis = batch_definition
                if op.batch_axis is not None:
                    op.is_batch_axis_seed = True
                    op.batch_axis_source_ids.append(op.id)
                    equivalent_source_ids[op.id] = {op.id}
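    # special case: output 0 of batch-norm-training is the normalized data and shares
    # the batch axis of its input, so the corresponding get-tuple-element is batch-propagable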
    for op in hlo_op_list:
        if op.opcode == 'get-tuple-element':
            if id_to_op[op.operand_ids[0]].opcode == 'batch-norm-training':
                if op.inst.tuple_index == 0:
                    op.batch_propagable_ids = [op.id, *op.operand_ids]
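    # matcher for a reshape -> transpose -> reshape -> convolution -> reshape -> transpose
    # chain; returns the six ops in dataflow order when `op` is the final transpose of
    # such a chain, otherwise None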
    def match_rtrcrt(op):
        if op.opcode != 'transpose':
            return None
        out_r_op = id_to_op[op.operand_ids[0]]
        if out_r_op.opcode != 'reshape':
            return None
        convolution_op = id_to_op[out_r_op.operand_ids[0]]
        if convolution_op.opcode != 'convolution':
            return None
        conv_out_non_batch_shape = convolution_op.shape[1:]
        if out_r_op.shape[-len(conv_out_non_batch_shape):] != conv_out_non_batch_shape:
            return None
        in_r_op = id_to_op[convolution_op.operand_ids[0]]
        if in_r_op.opcode != 'reshape':
            return None
        in_t_op = id_to_op[in_r_op.operand_ids[0]]
        if in_t_op.opcode != 'transpose':
            return None
        conv_in_non_batch_shape = in_r_op.shape[1:]
        if in_t_op.shape[-len(conv_in_non_batch_shape):] != conv_in_non_batch_shape:
            return None
        if len(in_t_op.inst.dimensions) != len(op.inst.dimensions):
            return None
        batch_dim = 0
        transposed_batch_dim = list(in_t_op.inst.dimensions).index(batch_dim)
        if op.inst.dimensions[batch_dim] != transposed_batch_dim:
            return None
        r_op = id_to_op[in_t_op.operand_ids[0]]
        if r_op.opcode != 'reshape':
            return None
        if r_op.batch_propagable_ids is None:
            return None
        return r_op, in_t_op, in_r_op, convolution_op, out_r_op, op
    # reshape -> transpose -> reshape -> convolution -> reshape -> transpose pattern
    for op in hlo_op_list:
        rtrcrt = match_rtrcrt(op)
        if rtrcrt is not None and op.batch_propagable_ids is None:
            r_op, *_ = rtrcrt
            op.batch_propagable_ids = [op.id, r_op.id]
            r_op.batch_propagable_ids = [r_op.id, r_op.operand_ids[0], op.id]
    # set up neighboring nodes that can propagate a batch axis to each other
    for op in hlo_op_list:
        if op.batch_propagable_ids is not None:
            op.batch_propagable_neighbor_ids.extend(op.batch_propagable_ids)
            propagable_ops = [id_to_op[bpid] for bpid in op.batch_propagable_ids]
            for pop in propagable_ops:
                if op.id not in pop.batch_propagable_neighbor_ids:
                    pop.batch_propagable_neighbor_ids.append(op.id)
    # propagate batch dimension information by traversing the graph
    source_op_candidates = [op for op in hlo_op_list if op.is_batch_axis_seed]
    source_id_to_root_id = {}
    while source_op_candidates:
        source_op = source_op_candidates.pop()
        visited_ids = set()
        stack = [source_op]
        while stack:
            current_op = stack.pop()
            if current_op.id in visited_ids:
                continue
            visited_ids.add(current_op.id)
            if current_op.batch_propagable_ids is not None:
                propagable_ops = [id_to_op[bpid] for bpid in current_op.batch_propagable_ids]
                for op in propagable_ops:
                    if op.is_batch_axis_seed:
                        equivalent_source_ids[source_op.id].update(op.batch_axis_source_ids)
                    else:
                        op.batch_axis = source_op.batch_axis
                        if source_op.id not in op.batch_axis_source_ids:
                            op.batch_axis_source_ids.append(source_op.id)
            if current_op.legal_opcode != 'tuple':
                # tuple should be a sink node that doesn't link different outputs together
                stack.extend(id_to_op[oid] for oid in current_op.batch_propagable_neighbor_ids)
        # remove transitively equivalent source candidates
        visited_ids = set()
        stack = list(equivalent_source_ids.keys())
        while stack:
            current_id = stack.pop()
            if current_id in visited_ids:
                continue
            visited_ids.add(current_id)
            if current_id not in source_id_to_root_id:
                source_id_to_root_id[current_id] = current_id
            root_id = source_id_to_root_id[current_id]
            for equivalent_id in equivalent_source_ids[current_id]:
                source_id_to_root_id[equivalent_id] = root_id
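        # keep only the candidates whose root differs from the root just processed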
        root_id = source_id_to_root_id[source_op.id]
        remaining_candidates = []
        for op in source_op_candidates:
            if source_id_to_root_id[op.batch_axis_source_ids[0]] != root_id:
                remaining_candidates.append(op)
        source_op_candidates = remaining_candidates
    # enable dynamic batch size only if there is a single root source of batch dimension
    parameter_ops = []
    for name in self.hlo_module.host_program_shape.parameter_names:
        parameter_id = self.parameter_name_to_id[name]
        parameter_ops.append(id_to_op[parameter_id])
    output_ops = [id_to_op[oid] for oid in self.output_tuple_op.operand_ids]
    io_ops = parameter_ops + output_ops
    io_ops_with_batch_axis = [op for op in io_ops if op.batch_axis is not None]
    if not io_ops_with_batch_axis:
        return
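    # collect the set of root batch-axis sources feeding each IO op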
    io_ops_root_ids = []
    for op in io_ops_with_batch_axis:
        root_ids = {source_id_to_root_id.get(sid, sid) for sid in op.batch_axis_source_ids}
        io_ops_root_ids.append(root_ids)
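    # proceed only if every IO op traces back to exactly one root source of the batch dimension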
    if all(len(root_ids) == 1 for root_ids in io_ops_root_ids):
        io_ops_source_id = io_ops_with_batch_axis[-1].batch_axis_source_ids[0]
        io_ops_root_id = source_id_to_root_id.get(io_ops_source_id, io_ops_source_id)
        reject = False
        for op in hlo_op_list:
            root_ids = {source_id_to_root_id.get(sid, sid) for sid in op.batch_axis_source_ids}
            if io_ops_root_id in root_ids and len(root_ids) != 1:
                # source conflict on instructions affecting IO instructions; reject
                op.batch_axis = None
                reject = True
        if reject:
            return
    else:
        # source conflict on IO instructions directly; reject
        return
    # TODO: this will be unnecessary once BatchHloInstructionPool is fully populated
    if any(op.batch_axis is None for op in io_ops):
        return
    # reshape -> transpose -> reshape -> convolution -> reshape -> transpose pattern again
    for op in hlo_op_list:
        rtrcrt = match_rtrcrt(op)
        if rtrcrt is not None:
            if [o.batch_axis for o in rtrcrt] == [0, None, None, 0, None, 0]:
                r_op, in_t_op, in_r_op, convolution_op, out_r_op, out_t_op = rtrcrt
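                # the input transpose relocates the batch dimension; any leading dimensions
                # that the reshape folds together with it act as a batch size multiplier on
                # the convolution's effective batch dimension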
                in_t_op.batch_axis = list(in_t_op.inst.dimensions).index(0)
                conv_in_non_batch_shape = in_r_op.shape[1:]
                in_t_op_multiplier_shape = in_t_op.shape[:-len(conv_in_non_batch_shape)]
                in_t_op_multiplier_shape.pop(in_t_op.batch_axis)
                batch_size_multiplier = int(np.prod(in_t_op_multiplier_shape))
                in_r_op.batch_axis = 0
                in_r_op.batch_size_multiplier = batch_size_multiplier
                convolution_op.batch_axis = 0
                convolution_op.batch_size_multiplier = batch_size_multiplier
                out_r_op.batch_axis = out_t_op.inst.dimensions[out_t_op.batch_axis]
    # write input_batch_axis and output_batch_axis in runtime format
    input_batch_axis = [op.batch_axis for op in parameter_ops]
    output_batch_axis = [op.batch_axis for op in output_ops]
    _assert_same_len(self.inputs, input_batch_axis, 'inputs', 'input_batch_axis')
    _assert_same_len(self.outputs, output_batch_axis, 'outputs', 'output_batch_axis')
    for args in [self.inputs, input_batch_axis], [self.outputs, output_batch_axis]:
        for ts, axis in zip(*args):
            ts.batch_axis = axis
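    # remember every op that carries a batch axis, skipping ops with tuple outputs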
    self.hlo_ops_with_batch_axis = []
    for op in hlo_op_list:
        if op.batch_axis is not None:
            if op.opcode not in HloOptimizer.tuple_output_opcodes:
                self.hlo_ops_with_batch_axis.append(op)
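    # record the module's current batch size as the original batch size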
    self.original_batch_size = self.get_batch_size()