def maybe_rewrite_batch_size()

in python/hlo/optimize.py [0:0]
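
Analyzes whether every instruction in the entry computation has known batch
semantics and, if the current batch size would overflow the IO queue buffer or
the cache, rewrites the graph to a smaller batch size estimated from those
budgets.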


    def maybe_rewrite_batch_size(self):
        batch_size = self.get_batch_size()
        if batch_size is None:
            return

        # disallow rewriting if some ops have unknown batch semantics
        hlo_op_list = [HloOp(inst) for inst in self.entry_instructions]
        id_to_op = {op.id: op for op in hlo_op_list}
        non_batch_ids = set()
        for op in hlo_op_list:
            op.input_shapes = [id_to_op[oid].shape for oid in op.operand_ids]
            op.propagable_ids = []
            op.is_non_batch = False
        batched_ids = {op.id for op in self.hlo_ops_with_batch_axis}
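        # ask BatchHloInstructionPool for each opcode's batch semantics; any id or
        # operand not listed as propagable is recorded as non-batch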
        for op in hlo_op_list:
            func = getattr(BatchHloInstructionPool, op.legal_opcode, None)
            if func is not None:
                batch_definition = func(op)
                if batch_definition is None:
                    batch_definition = [op.id], None
                op.propagable_ids, _ = batch_definition
                id_oids = [op.id, *op.operand_ids]
                non_batch_ids.update(i for i in id_oids if i not in op.propagable_ids)
        non_batch_ids.difference_update(batched_ids)
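        # make propagable edges bidirectional so the non-batch marking can spread
        # both upstream and downstream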
        for op in hlo_op_list:
            for oid in op.operand_ids:
                oop = id_to_op[oid]
                if oop.id in op.propagable_ids and op.id not in oop.propagable_ids:
                    oop.propagable_ids.append(op.id)
        for op in hlo_op_list:
            if op.id in non_batch_ids:
                op.is_non_batch = True
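        # flood-fill the non-batch marking through the propagable-id graph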
        visited_ids = set()
        stack = [op for op in hlo_op_list if op.is_non_batch]
        while stack:
            non_batch_op = stack.pop()
            if non_batch_op.id in visited_ids:
                continue
            visited_ids.add(non_batch_op.id)
            non_batch_op.is_non_batch = True
            stack.extend(id_to_op[pid] for pid in non_batch_op.propagable_ids)
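        # rewrite only when every entry instruction is classified and no op is
        # marked both batched and non-batched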
        non_batched_ids = {op.id for op in hlo_op_list if op.is_non_batch}
        all_analyzed_ids = batched_ids.union(non_batched_ids)
        all_analyzed_ids.add(self.output_tuple_op.id)
        if len(all_analyzed_ids) != len(self.entry_instructions):
            return
        if batched_ids.intersection(non_batched_ids):
            return

        # rewrite if IO buffer demand or cache demand is too high
        inputs, outputs = self.get_io_tensors()
        total_io_num_bytes = sum(ts.num_bytes for ts in inputs + outputs)
        io_queue_num_bytes = total_io_num_bytes * _DEFAULT_IO_QUEUE_DEPTH
        io_queue_too_large = io_queue_num_bytes > _DEFAULT_IO_BUFFER_NUM_BYTES
        cache_demand_num_bytes = self.estimate_cache_demand()
        if cache_demand_num_bytes is None:
            cache_demand_too_high = False
        else:
            cache_demand_too_high = cache_demand_num_bytes > 2 * _DEFAULT_CACHE_CAPACITY
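        # nothing to rewrite if the current batch size fits both budgets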
        if not (io_queue_too_large or cache_demand_too_high):
            return
        bytes_to_mbytes = lambda num_bytes: int(num_bytes / 1024 / 1024)
        if batch_size == 1:
            return
        if io_queue_too_large:
            num_mb = bytes_to_mbytes(io_queue_num_bytes)
            reason = 'batch size {} would require {} MB IO buffer'.format(batch_size, num_mb)
        elif cache_demand_too_high:
            num_mb = bytes_to_mbytes(cache_demand_num_bytes)
            reason = 'batch size {} would create {} MB cache demand'.format(batch_size, num_mb)
        logger.warning('{}; rewriting batch size to mitigate'.format(reason))
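        # rewrite to batch size 1 first; the per-sample IO and cache footprints
        # measured below drive the final batch size estimate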
        self.rewrite_batch_size(1)

        # estimate from IO queue size
        inputs, outputs = self.get_io_tensors()
        total_io_num_bytes = sum(ts.num_bytes for ts in inputs + outputs)
        io_queue_num_bytes = total_io_num_bytes * _DEFAULT_IO_QUEUE_DEPTH
        batch_size_from_io = round(_DEFAULT_IO_BUFFER_NUM_BYTES / io_queue_num_bytes)

        # estimate from cache demand (the estimate may be unavailable, e.g. None)
        cache_demand = self.estimate_cache_demand()
        if cache_demand:
            batch_size_from_cache = round(_DEFAULT_CACHE_CAPACITY / cache_demand)
        else:
            batch_size_from_cache = batch_size

        # choose the smaller estimate
        batch_size = min(batch_size, batch_size_from_cache, batch_size_from_io)
        batch_size = max(batch_size, 1)
        reason = 'IO queue size' if batch_size_from_io < batch_size_from_cache else 'cache demand'
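        # round down to a multiple of 64 for large batches, otherwise to a power of two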
        if batch_size > 64:
            batch_size = batch_size // 64 * 64
        else:
            batch_size = 2 ** (batch_size.bit_length() - 1)
        logger.info('estimated optimal batch size {} from {}'.format(batch_size, reason))
        self.rewrite_batch_size(batch_size, final=True)

        # change input_shuffles to new batch size
        def change_batch_size(shuffle):
            if shuffle is None:
                return shuffle
            else:
                return shuffle.reshape([batch_size, -1])[:batch_size].ravel()

        if self.input_shuffles is not None:
            self.input_shuffles = [change_batch_size(shuffle) for shuffle in self.input_shuffles]

        # change host program shape and entry computation program shape as well
        self._reestablish_program_shapes()
        self._legalize_instructions()
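
For reference, a minimal hypothetical sketch of what a BatchHloInstructionPool
handler looks like, inferred only from how this method consumes its return
value: a handler is looked up by the op's legal_opcode and returns either None
or a 2-tuple whose first element is the list of propagable instruction ids
(the second element is unused here). The multiply handler below is an
illustrative assumption, not taken from the source.

    class BatchHloInstructionPool:

        @staticmethod
        def multiply(op):
            # hypothetical elementwise handler: the batch dimension propagates
            # between the op and all of its operands
            return [op.id, *op.operand_ids], None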