in python/hlo/optimize.py [0:0]
def maybe_rewrite_batch_size(self):
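    """Heuristically shrink the model batch size when it creates excessive
    IO-queue or cache demand.

    The method first classifies every HLO op as batched or non-batched and
    returns early if the classification is incomplete or inconsistent. It then
    rewrites the model to batch size 1, re-estimates per-sample IO and cache
    costs, and commits a smaller batch size that fits the default budgets.
    """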
    batch_size = self.get_batch_size()
    if batch_size is None:
        return
    # disallow rewriting if some ops have unknown batch semantics
    hlo_op_list = [HloOp(inst) for inst in self.entry_instructions]
    id_to_op = {op.id: op for op in hlo_op_list}
    non_batch_ids = set()
    for op in hlo_op_list:
        op.input_shapes = [id_to_op[oid].shape for oid in op.operand_ids]
        op.propagable_ids = []
        op.is_non_batch = False
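    # BatchHloInstructionPool (defined outside this excerpt) maps opcodes to
    # functions describing their batch semantics; ops without an entry keep an
    # empty propagable_ids and are treated as non-batch below.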
    batched_ids = {op.id for op in self.hlo_ops_with_batch_axis}
    for op in hlo_op_list:
        func = getattr(BatchHloInstructionPool, op.legal_opcode, None)
        if func is not None:
            batch_definition = func(op)
            if batch_definition is None:
                batch_definition = [op.id], None
            op.propagable_ids, _ = batch_definition
        id_oids = [op.id, *op.operand_ids]
        non_batch_ids.update(i for i in id_oids if i not in op.propagable_ids)
    non_batch_ids.difference_update(batched_ids)
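    # Make propagation edges symmetric: whenever an op lists an operand as
    # propagable, the operand also lists the op, so non-batch status can flow
    # in both directions.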
    for op in hlo_op_list:
        for oid in op.operand_ids:
            oop = id_to_op[oid]
            if oop.id in op.propagable_ids and op.id not in oop.propagable_ids:
                oop.propagable_ids.append(op.id)
    for op in hlo_op_list:
        if op.id in non_batch_ids:
            op.is_non_batch = True
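    # Depth-first propagation: every op reachable from a non-batch seed
    # through propagable edges is also non-batch.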
    visited_ids = set()
    stack = [op for op in hlo_op_list if op.is_non_batch]
    while stack:
        non_batch_op = stack.pop()
        if non_batch_op.id in visited_ids:
            continue
        visited_ids.add(non_batch_op.id)
        non_batch_op.is_non_batch = True
        stack.extend(id_to_op[pid] for pid in non_batch_op.propagable_ids)
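    # Trust the analysis only if the batched and non-batched sets together
    # cover every entry instruction exactly once (the output tuple is exempt
    # from classification).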
    non_batched_ids = {op.id for op in hlo_op_list if op.is_non_batch}
    all_analyzed_ids = batched_ids.union(non_batched_ids)
    all_analyzed_ids.add(self.output_tuple_op.id)
    if len(all_analyzed_ids) != len(self.entry_instructions):
        return
    if batched_ids.intersection(non_batched_ids):
        return
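    # _DEFAULT_IO_QUEUE_DEPTH, _DEFAULT_IO_BUFFER_NUM_BYTES and
    # _DEFAULT_CACHE_CAPACITY are module-level constants defined outside this
    # excerpt.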
    # rewrite if IO buffer demand or cache demand is too high
    inputs, outputs = self.get_io_tensors()
    total_io_num_bytes = sum(ts.num_bytes for ts in inputs + outputs)
    io_queue_num_bytes = total_io_num_bytes * _DEFAULT_IO_QUEUE_DEPTH
    io_queue_too_large = io_queue_num_bytes > _DEFAULT_IO_BUFFER_NUM_BYTES
    cache_demand_num_bytes = self.estimate_cache_demand()
    if cache_demand_num_bytes is None:
        cache_demand_too_high = False
    else:
        cache_demand_too_high = cache_demand_num_bytes > 2 * _DEFAULT_CACHE_CAPACITY
    if not (io_queue_too_large or cache_demand_too_high):
        return
    bytes_to_mbytes = lambda num_bytes: int(num_bytes / 1024 / 1024)
    if batch_size == 1:
        return
    if io_queue_too_large:
        num_mb = bytes_to_mbytes(io_queue_num_bytes)
        reason = 'batch size {} would require {} MB IO buffer'.format(batch_size, num_mb)
    elif cache_demand_too_high:
        num_mb = bytes_to_mbytes(cache_demand_num_bytes)
        reason = 'batch size {} would create {} MB cache demand'.format(batch_size, num_mb)
    logger.warning('{}; rewriting batch size to mitigate'.format(reason))
    self.rewrite_batch_size(1)
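    # With the model rewritten to batch size 1, the figures recomputed below
    # are per-sample costs; dividing each capacity by its per-sample cost
    # yields an upper bound on the batch size.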
    # estimate from IO queue size
    inputs, outputs = self.get_io_tensors()
    total_io_num_bytes = sum(ts.num_bytes for ts in inputs + outputs)
    io_queue_num_bytes = total_io_num_bytes * _DEFAULT_IO_QUEUE_DEPTH
    batch_size_from_io = round(_DEFAULT_IO_BUFFER_NUM_BYTES / io_queue_num_bytes)
    # estimate from cache demand; estimate_cache_demand() may return None (as
    # handled above), in which case the cache does not constrain the batch size
    cache_demand = self.estimate_cache_demand()
    if cache_demand:
        batch_size_from_cache = round(_DEFAULT_CACHE_CAPACITY / cache_demand)
    else:
        batch_size_from_cache = batch_size
    # keep the pre-rewrite batch size for the input_shuffles fix-up below
    original_batch_size = batch_size
    # choose the smallest estimate, clamped to at least 1
    batch_size = min(batch_size, batch_size_from_cache, batch_size_from_io)
    batch_size = max(batch_size, 1)
    reason = 'IO queue size' if batch_size_from_io < batch_size_from_cache else 'cache demand'
    if batch_size > 64:
        # round down to a multiple of 64
        batch_size = batch_size // 64 * 64
    else:
        # round down to the largest power of two <= batch_size
        batch_size = 2 ** (batch_size.bit_length() - 1)
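    # Worked example: an estimate of 100 becomes 100 // 64 * 64 == 64, while
    # an estimate of 50 becomes 2 ** ((50).bit_length() - 1) == 32.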
    logger.info('estimated optimal batch size {} from {}'.format(batch_size, reason))
    self.rewrite_batch_size(batch_size, final=True)
    # change input_shuffles to new batch size
    def change_batch_size(shuffle):
        if shuffle is None:
            return shuffle
        # Assumption: each row of the reshaped array corresponds to one element
        # of the original batch, so the slice keeps the first batch_size rows.
        return shuffle.reshape([original_batch_size, -1])[:batch_size].ravel()
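    # For example, with a hypothetical numpy shuffle of length 8 built for an
    # original batch of 4 and a rewritten batch of 2:
    #   np.arange(8).reshape([4, -1])[:2].ravel()  ->  array([0, 1, 2, 3])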
    if self.input_shuffles is not None:
        self.input_shuffles = [change_batch_size(shuffle) for shuffle in self.input_shuffles]
    # change host program shape and entry computation program shape as well
    self._reestablish_program_shapes()
    self._legalize_instructions()