def maybe_enable_dynamic_batch_size()

in python/hlo/optimize.py


    def maybe_enable_dynamic_batch_size(self):
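        # Wrap every entry instruction in an HloOp and reset the per-op metadata
        # used below to discover and propagate a batch axis.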
        hlo_op_list = [HloOp(inst) for inst in self.entry_instructions]
        id_to_op = {op.id: op for op in hlo_op_list}
        for op in hlo_op_list:
            op.input_shapes = [id_to_op[oid].shape for oid in op.operand_ids]
            op.consumer_ids = []
            op.batch_propagable_ids = None
            op.batch_propagable_neighbor_ids = []
            op.batch_axis = None
            op.batch_size_multiplier = 1
            op.is_batch_axis_seed = False
            op.batch_axis_source_ids = []
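        # Record producer -> consumer edges so the graph can be walked in both directions.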
        for op in hlo_op_list:
            for oid in op.operand_ids:
                oop = id_to_op[oid]
                if op.id not in oop.consumer_ids:
                    oop.consumer_ids.append(op.id)
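        # Query BatchHloInstructionPool for each opcode's batch definition; ops that
        # report a concrete batch axis become seeds for the propagation below.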
        equivalent_source_ids = {}
        for op in hlo_op_list:
            func = getattr(BatchHloInstructionPool, op.legal_opcode, None)
            if func is not None:
                batch_definition = func(op)
                if batch_definition is not None:
                    op.batch_propagable_ids, op.batch_axis = batch_definition
                    if op.batch_axis is not None:
                        op.is_batch_axis_seed = True
                        op.batch_axis_source_ids.append(op.id)
                        equivalent_source_ids[op.id] = {op.id}
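        # Special case: the first output (tuple index 0) of batch-norm-training has
        # the same shape as its input, so the surrounding get-tuple-element may
        # propagate the batch axis to and from the batch-norm op.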
        for op in hlo_op_list:
            if op.opcode == 'get-tuple-element':
                if id_to_op[op.operand_ids[0]].opcode == 'batch-norm-training':
                    if op.inst.tuple_index == 0:
                        op.batch_propagable_ids = [op.id, *op.operand_ids]

        def match_rtrcrt(op):
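            # Match a reshape -> transpose -> reshape -> convolution -> reshape ->
            # transpose chain ending at `op`; return the six ops in graph order, or
            # None if the pattern does not apply.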
            if op.opcode != 'transpose':
                return None
            out_r_op = id_to_op[op.operand_ids[0]]
            if out_r_op.opcode != 'reshape':
                return None
            convolution_op = id_to_op[out_r_op.operand_ids[0]]
            if convolution_op.opcode != 'convolution':
                return None
            conv_out_non_batch_shape = convolution_op.shape[1:]
            if out_r_op.shape[-len(conv_out_non_batch_shape):] != conv_out_non_batch_shape:
                return None
            in_r_op = id_to_op[convolution_op.operand_ids[0]]
            if in_r_op.opcode != 'reshape':
                return None
            in_t_op = id_to_op[in_r_op.operand_ids[0]]
            if in_t_op.opcode != 'transpose':
                return None
            conv_in_non_batch_shape = in_r_op.shape[1:]
            if in_t_op.shape[-len(conv_in_non_batch_shape):] != conv_in_non_batch_shape:
                return None
            if len(in_t_op.inst.dimensions) != len(op.inst.dimensions):
                return None
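            # the output transpose must put the batch dimension (relocated by the
            # input transpose) back on axis 0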
            batch_dim = 0
            transposed_batch_dim = list(in_t_op.inst.dimensions).index(batch_dim)
            if op.inst.dimensions[batch_dim] != transposed_batch_dim:
                return None
            r_op = id_to_op[in_t_op.operand_ids[0]]
            if r_op.opcode != 'reshape':
                return None
            if r_op.batch_propagable_ids is None:
                return None
            return r_op, in_t_op, in_r_op, convolution_op, out_r_op, op

        # reshape -> transpose -> reshape -> convolution -> reshape -> transpose pattern
        for op in hlo_op_list:
            rtrcrt = match_rtrcrt(op)
            if rtrcrt is not None and op.batch_propagable_ids is None:
                r_op, *_ = rtrcrt
                op.batch_propagable_ids = [op.id, r_op.id]
                r_op.batch_propagable_ids = [r_op.id, r_op.operand_ids[0], op.id]

        # setup neighboring nodes
        for op in hlo_op_list:
            if op.batch_propagable_ids is not None:
                op.batch_propagable_neighbor_ids.extend(op.batch_propagable_ids)
                propagable_ops = [id_to_op[bpid] for bpid in op.batch_propagable_ids]
                for pop in propagable_ops:
                    if op.id not in pop.batch_propagable_neighbor_ids:
                        pop.batch_propagable_neighbor_ids.append(op.id)

        # propagate batch dimension information by traversing the graph
        source_op_candidates = [op for op in hlo_op_list if op.is_batch_axis_seed]
        source_id_to_root_id = {}
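        # Flood fill from each seed: non-seed ops inherit the seed's batch axis,
        # and seeds reachable from one another are recorded as equivalent so they
        # can later be collapsed to a single root.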
        while source_op_candidates:
            source_op = source_op_candidates.pop()
            visited_ids = set()
            stack = [source_op]
            while stack:
                current_op = stack.pop()
                if current_op.id in visited_ids:
                    continue
                visited_ids.add(current_op.id)
                if current_op.batch_propagable_ids is not None:
                    propagable_ops = [id_to_op[bpid] for bpid in current_op.batch_propagable_ids]
                    for op in propagable_ops:
                        if op.is_batch_axis_seed:
                            equivalent_source_ids[source_op.id].update(op.batch_axis_source_ids)
                        else:
                            op.batch_axis = source_op.batch_axis
                        if source_op.id not in op.batch_axis_source_ids:
                            op.batch_axis_source_ids.append(source_op.id)
                if current_op.legal_opcode != 'tuple':
                    # tuple should be a sink node that doesn't link different outputs together
                    stack.extend(id_to_op[oid] for oid in current_op.batch_propagable_neighbor_ids)

            # remove transitively equivalent source candidates
            visited_ids = set()
            stack = list(equivalent_source_ids.keys())
            while stack:
                current_id = stack.pop()
                if current_id in visited_ids:
                    continue
                visited_ids.add(current_id)
                if current_id not in source_id_to_root_id:
                    source_id_to_root_id[current_id] = current_id
                root_id = source_id_to_root_id[current_id]
                for equivalent_id in equivalent_source_ids[current_id]:
                    source_id_to_root_id[equivalent_id] = root_id
            root_id = source_id_to_root_id[source_op.id]
            # keep only seed candidates that are not yet unified with this root
            remaining_candidates = []
            for op in source_op_candidates:
                source_id = op.batch_axis_source_ids[0]
                if source_id_to_root_id.get(source_id, source_id) != root_id:
                    remaining_candidates.append(op)
            source_op_candidates = remaining_candidates

        # enable dynamic batch size only if there is a single root source of batch dimension
        parameter_ops = []
        for name in self.hlo_module.host_program_shape.parameter_names:
            parameter_id = self.parameter_name_to_id[name]
            parameter_ops.append(id_to_op[parameter_id])
        output_ops = [id_to_op[oid] for oid in self.output_tuple_op.operand_ids]
        io_ops = parameter_ops + output_ops
        io_ops_with_batch_axis = [op for op in io_ops if op.batch_axis is not None]
        if not io_ops_with_batch_axis:
            return
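        # map every batch-axis source recorded on an IO op to its canonical root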
        io_ops_root_ids = []
        for op in io_ops_with_batch_axis:
            root_ids = {source_id_to_root_id.get(sid, sid) for sid in op.batch_axis_source_ids}
            io_ops_root_ids.append(root_ids)
        if all(len(root_ids) == 1 for root_ids in io_ops_root_ids):
            io_ops_source_id = io_ops_with_batch_axis[-1].batch_axis_source_ids[0]
            io_ops_root_id = source_id_to_root_id.get(io_ops_source_id, io_ops_source_id)
            reject = False
            for op in hlo_op_list:
                root_ids = {source_id_to_root_id.get(sid, sid) for sid in op.batch_axis_source_ids}
                if io_ops_root_id in root_ids and len(root_ids) != 1:
                    # source conflict on instructions affecting IO instructions; reject
                    op.batch_axis = None
                    reject = True
            if reject:
                return
        else:
            # source conflict on IO instructions directly; reject
            return

        # TODO: this will be unnecessary once BatchHloInstructionPool is fully populated
        if any(op.batch_axis is None for op in io_ops):
            return

        # reshape -> transpose -> reshape -> convolution -> reshape -> transpose pattern again
        for op in hlo_op_list:
            rtrcrt = match_rtrcrt(op)
            if rtrcrt is not None:
                if [o.batch_axis for o in rtrcrt] == [0, None, None, 0, None, 0]:
                    r_op, in_t_op, in_r_op, convolution_op, out_r_op, out_t_op = rtrcrt
                    in_t_op.batch_axis = list(in_t_op.inst.dimensions).index(0)
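                    # the input reshape folds the remaining leading dimensions into
                    # the batch dimension, so the reshape and convolution see the
                    # batch size scaled by their product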
                    conv_in_non_batch_shape = in_r_op.shape[1:]
                    in_t_op_multiplier_shape = in_t_op.shape[:-len(conv_in_non_batch_shape)]
                    in_t_op_multiplier_shape.pop(in_t_op.batch_axis)
                    batch_size_multiplier = int(np.prod(in_t_op_multiplier_shape))
                    in_r_op.batch_axis = 0
                    in_r_op.batch_size_multiplier = batch_size_multiplier
                    convolution_op.batch_axis = 0
                    convolution_op.batch_size_multiplier = batch_size_multiplier
                    out_r_op.batch_axis = out_t_op.inst.dimensions[out_t_op.batch_axis]

        # write input_batch_axis and output_batch_axis in runtime format
        input_batch_axis = [op.batch_axis for op in parameter_ops]
        output_batch_axis = [op.batch_axis for op in output_ops]
        _assert_same_len(self.inputs, input_batch_axis, 'inputs', 'input_batch_axis')
        _assert_same_len(self.outputs, output_batch_axis, 'outputs', 'output_batch_axis')
        for args in [self.inputs, input_batch_axis], [self.outputs, output_batch_axis]:
            for ts, axis in zip(*args):
                ts.batch_axis = axis
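        # remember every op carrying a batch axis (opcodes with tuple outputs excluded)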
        self.hlo_ops_with_batch_axis = []
        for op in hlo_op_list:
            if op.batch_axis is not None:
                if op.opcode not in HloOptimizer.tuple_output_opcodes:
                    self.hlo_ops_with_batch_axis.append(op)
        self.original_batch_size = self.get_batch_size()
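
The propagation above reduces to a simple idea: a few ops seed a batch axis, the axis is flood-filled along batch-propagable edges, and seeds that reach one another are collapsed under a single root so the "single root source" check can be made. Below is a minimal, self-contained sketch of that idea; `ToyOp` and `propagate_batch_axis` are illustrative names and are not part of optimize.py.

    from collections import deque

    class ToyOp:
        def __init__(self, op_id, batch_axis=None, neighbors=()):
            self.id = op_id
            self.batch_axis = batch_axis          # set only on seed ops
            self.is_seed = batch_axis is not None
            self.neighbor_ids = list(neighbors)   # batch-propagable neighbors

    def propagate_batch_axis(ops):
        id_to_op = {op.id: op for op in ops}
        root = {op.id: op.id for op in ops if op.is_seed}
        for seed in (op for op in ops if op.is_seed):
            seen, queue = set(), deque([seed.id])
            while queue:
                cur = id_to_op[queue.popleft()]
                if cur.id in seen:
                    continue
                seen.add(cur.id)
                if cur.is_seed:
                    root[cur.id] = root[seed.id]  # unify equivalent seeds
                else:
                    cur.batch_axis = seed.batch_axis
                queue.extend(cur.neighbor_ids)
        return root

    # two parameters feeding one add: both seeds collapse to a single root
    ops = [ToyOp('p0', batch_axis=0, neighbors=['add']),
           ToyOp('p1', batch_axis=0, neighbors=['add']),
           ToyOp('add', neighbors=['p0', 'p1'])]
    roots = propagate_batch_axis(ops)
    assert len(set(roots.values())) == 1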