mysqlx-connector-python/lib/mysqlx/protocol.py (737 lines of code) (raw):
# Copyright (c) 2016, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implementation of the X protocol for MySQL servers."""
import struct
import zlib
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
try:
import lz4.frame
HAVE_LZ4 = True
except ImportError:
HAVE_LZ4 = False
try:
import zstandard as zstd
HAVE_ZSTD = True
except ImportError:
HAVE_ZSTD = False
from .errors import (
InterfaceError,
NotSupportedError,
OperationalError,
ProgrammingError,
)
from .expr import (
ExprParser,
build_bool_scalar,
build_expr,
build_int_scalar,
build_scalar,
build_unsigned_int_scalar,
)
from .helpers import encode_to_bytes, get_item_or_attr
from .logger import logger
from .protobuf import CRUD_PREPARE_MAPPING, SERVER_MESSAGES, Message, mysqlxpb_enum
from .result import Column
from .statement import (
AddStatement,
DeleteStatement,
FilterableStatement,
FindStatement,
InsertStatement,
ModifyStatement,
ReadStatement,
RemoveStatement,
SqlStatement,
UpdateStatement,
)
from .types import (
ColumnType,
MessageType,
ProtobufMessageCextType,
ProtobufMessageType,
ResultBaseType,
SocketType,
StatementType,
StrOrBytes,
)
# Size limit used by MessageWriter.write_message(): when compression is
# enabled, only messages whose `byte_size` is strictly greater than this
# value are sent inside a compression frame.
_COMPRESSION_THRESHOLD = 1000
class Compressor:
    """Implements compression/decompression using `zstd_stream`, `lz4_message`
    and `deflate_stream` algorithms.

    Args:
        algorithm (str): Compression algorithm.

    Raises:
        :class:`mysqlx.NotSupportedError`: If the optional package needed by
                                            `zstd_stream` or `lz4_message` is
                                            not installed.

    .. versionadded:: 8.0.21
    """

    def __init__(self, algorithm: str) -> None:
        self._algorithm: str = algorithm
        # The `*_stream` algorithms keep one (de)compression object for the
        # lifetime of this Compressor; `lz4_message` creates fresh ones per
        # call, so these stay None for it.
        self._compressobj: Any = None
        self._decompressobj: Any = None

        if algorithm == "zstd_stream":
            # Fail early with a clear message instead of a NameError on the
            # missing optional module.
            if not HAVE_ZSTD:
                raise NotSupportedError(
                    "The 'zstandard' package is required to use the "
                    "'zstd_stream' compression algorithm"
                )
            self._compressobj = zstd.ZstdCompressor()
            self._decompressobj = zstd.ZstdDecompressor()
        elif algorithm == "lz4_message":
            if not HAVE_LZ4:
                raise NotSupportedError(
                    "The 'lz4' package is required to use the "
                    "'lz4_message' compression algorithm"
                )
        elif algorithm == "deflate_stream":
            self._compressobj = zlib.compressobj()
            self._decompressobj = zlib.decompressobj()

    def compress(self, data: StrOrBytes) -> bytes:
        """Compresses data and returns it.

        Args:
            data (str, bytes or buffer object): Data to be compressed.

        Returns:
            bytes: Compressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._compressobj.compress(data)
        if self._algorithm == "lz4_message":
            # Each message is written as its own complete LZ4 frame.
            with lz4.frame.LZ4FrameCompressor() as compressor:
                compressed = compressor.begin()
                compressed += compressor.compress(data)
                compressed += compressor.flush()
            return compressed
        # Using 'deflate_stream' algorithm. Z_SYNC_FLUSH emits all pending
        # output for this message while keeping the deflate stream open, so
        # later messages continue the same stream.
        compressed = self._compressobj.compress(data)
        compressed += self._compressobj.flush(zlib.Z_SYNC_FLUSH)
        return compressed

    def decompress(self, data: StrOrBytes) -> bytes:
        """Decompresses a frame of data and returns it as a string of bytes.

        Args:
            data (str, bytes or buffer object): Data to be decompressed.

        Returns:
            bytes: Decompressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._decompressobj.decompress(data)
        if self._algorithm == "lz4_message":
            with lz4.frame.LZ4FrameDecompressor() as decompressor:
                decompressed = decompressor.decompress(data)
            return decompressed
        # Using 'deflate_stream' algorithm.
        # NOTE(review): zlib's Decompress.flush() takes an initial output
        # buffer *length*, not a flush mode, so Z_SYNC_FLUSH only acts as a
        # buffer-size hint here; kept as-is for compatibility.
        decompressed = self._decompressobj.decompress(data)
        decompressed += self._decompressobj.flush(zlib.Z_SYNC_FLUSH)
        return decompressed
class MessageReader:
    """Implements a Message Reader.

    Reads X Protocol frames from a socket stream and turns them into
    :class:`mysqlx.protobuf.Message` objects, expanding compressed frames
    into the individual messages they carry.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None
        # One-message slot filled by push_message() and drained first by
        # read_message().
        self._msg: MessageType = None
        # Messages unpacked from a compressed frame, in arrival order.
        self._msg_queue: List[Message] = []

    def _read_message(self) -> MessageType:
        """Reads X Protocol messages from the stream and returns a
        :class:`mysqlx.protobuf.Message` object.

        Empty notice frames are skipped. Compressed frames are expanded and
        their inner messages are returned one per call.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the connected server does not
                                              have the MySQL X protocol plugin
                                              enabled.
            :class:`mysqlx.InterfaceError`: If a compressed frame arrives while
                                            compression is not enabled.
            ValueError: If the frame type is unknown.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        while not self._msg_queue:
            frame_size, frame_type = struct.unpack("<LB", self._stream.read(5))
            # NOTE(review): presumably a classic-protocol handshake, whose
            # protocol version byte (10) lands at this offset.
            if frame_type == 10:
                raise ProgrammingError(
                    "The connected server does not have the "
                    "MySQL X protocol plugin enabled or "
                    "protocol mismatch"
                )
            # frame_size counts the type byte that was already consumed.
            frame_payload = self._stream.read(frame_size - 1)
            if frame_type not in SERVER_MESSAGES:
                raise ValueError(f"Unknown message type: {frame_type}")
            # Do not parse empty notices, Message requires a type in payload
            if frame_type == 11 and frame_payload == b"":
                continue
            frame_msg = Message.from_server_message(frame_type, frame_payload)
            if frame_type != 19:  # Mysqlx.ServerMessages.Type.COMPRESSION
                return frame_msg
            # A compression frame may contain zero or more messages; loop
            # again if it yielded nothing instead of returning None.
            self._queue_compressed_messages(frame_msg)
        return self._msg_queue.pop(0)

    def _queue_compressed_messages(self, frame_msg: MessageType) -> None:
        """Decompresses a compression frame and queues the messages inside it.

        Args:
            frame_msg (mysqlx.protobuf.Message): The compression message.

        Raises:
            :class:`mysqlx.InterfaceError`: If compression is not enabled.
        """
        if self._compressor is None:
            raise InterfaceError(
                "Received a compressed message but compression is not enabled"
            )
        uncompressed_size = frame_msg["uncompressed_size"]
        stream = BytesIO(self._compressor.decompress(frame_msg["payload"]))
        bytes_processed = 0
        while bytes_processed < uncompressed_size:
            payload_size, msg_type = struct.unpack("<LB", stream.read(5))
            payload = stream.read(payload_size - 1)
            # Each inner frame is a 4-byte length followed by `payload_size`
            # bytes (1 type byte + body).
            bytes_processed += payload_size + 4
            # Same rule as for top-level frames: empty notices can't be parsed.
            if msg_type == 11 and payload == b"":
                continue
            self._msg_queue.append(Message.from_server_message(msg_type, payload))

    def read_message(self) -> MessageType:
        """Read message.

        Returns the pushed-back message if there is one, otherwise reads the
        next message from the stream.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        if self._msg is not None:
            msg = self._msg
            self._msg = None
            return msg
        return self._read_message()

    def push_message(self, msg: MessageType) -> None:
        """Push message.

        Stores `msg` so that the next read_message() call returns it.

        Args:
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        Raises:
            :class:`mysqlx.OperationalError`: If message push slot is full.
        """
        if self._msg is not None:
            raise OperationalError("Message push slot is full")
        self._msg = msg

    def set_compression(self, algorithm: str) -> None:
        """Creates a :class:`mysqlx.protocol.Compressor` object based on the
        compression algorithm.

        Args:
            algorithm (str): Compression algorithm; a falsy value disables
                             decompression.

        .. versionadded:: 8.0.21
        """
        self._compressor = Compressor(algorithm) if algorithm else None
class MessageWriter:
    """Serializes X Protocol messages and writes them to a socket stream.

    Once compression is enabled, messages larger than the compression
    threshold are wrapped in a ``Mysqlx.Connection.Compression`` frame.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None

    def write_message(self, msg_type: int, msg: MessageType) -> None:
        """Serialize `msg` and send it, compressing it when appropriate.

        Args:
            msg_type (int): The message type.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        """
        size = msg.byte_size(msg)
        body = encode_to_bytes(msg.serialize_to_string())
        # X Protocol frame: 4-byte little-endian length (type byte included)
        # followed by the 1-byte message type and the serialized body.
        plain_frame = b"".join([struct.pack("<LB", size + 1, msg_type), body])

        if not self._compressor or size <= _COMPRESSION_THRESHOLD:
            self._stream.sendall(plain_frame)
            return

        compressed = self._compressor.compress(plain_frame)

        head = Message("Mysqlx.Connection.Compression")
        head["client_messages"] = msg_type
        head["uncompressed_size"] = size + 5  # frame header + body

        tail = Message("Mysqlx.Connection.Compression")
        tail["payload"] = compressed

        # The Compression message is serialized in two parts and concatenated.
        # NOTE(review): the [:-2] presumably drops a 2-byte empty `payload`
        # field from the first part — confirm against the protobuf layer.
        wrapped = b"".join(
            [
                encode_to_bytes(head.serialize_partial_to_string())[:-2],
                encode_to_bytes(tail.serialize_partial_to_string()),
            ]
        )
        compression_id = mysqlxpb_enum("Mysqlx.ClientMessages.Type.COMPRESSION")
        prefix = struct.pack("<LB", len(wrapped) + 1, compression_id)
        self._stream.sendall(b"".join([prefix, wrapped]))

    def set_compression(self, algorithm: str) -> None:
        """Install a :class:`mysqlx.protocol.Compressor` for `algorithm`.

        Args:
            algorithm (str): Compression algorithm; a falsy value disables
                             compression.
        """
        if algorithm:
            self._compressor = Compressor(algorithm)
        else:
            self._compressor = None
class Protocol:
    """Implements the MySQL X Protocol.

    Args:
        reader (mysqlx.protocol.MessageReader): A Message Reader object.
        writer (mysqlx.protocol.MessageWriter): A Message Writer object.

    .. versionchanged:: 8.0.21
    """

    def __init__(self, reader: MessageReader, writer: MessageWriter) -> None:
        self._reader: MessageReader = reader
        self._writer: MessageWriter = writer
        self._compression_algorithm: Optional[str] = None
        # Text of every warning notice handled by _process_frame().
        self._warnings: List[str] = []

    @property
    def compression_algorithm(self) -> Optional[str]:
        """str: The compression algorithm."""
        return self._compression_algorithm

    @staticmethod
    def _apply_filter(msg: MessageType, stmt: FilterableStatement) -> None:
        """Copy the statement's where/sort/group-by/having clauses into `msg`.

        Args:
            msg (mysqlx.protobuf.Message): The MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.
        """
        if stmt.has_where:
            msg["criteria"] = stmt.get_where_expr()
        if stmt.has_sort:
            msg["order"].extend(stmt.get_sort_expr())
        if stmt.has_group_by:
            msg["grouping"].extend(stmt.get_grouping())
        if stmt.has_having:
            msg["grouping_criteria"] = stmt.get_having()

    def _create_object_any(self, pairs: Any) -> MessageType:
        """Build an ``Any`` of type OBJECT from ``(key, value)`` pairs.

        Args:
            pairs (iterable): ``(key, value)`` pairs; every value is converted
                              with :meth:`_create_any`.

        Returns:
            mysqlx.protobuf.Message: A ``Mysqlx.Datatypes.Any`` message.
        """
        obj_flds = [
            Message(
                "Mysqlx.Datatypes.Object.ObjectField",
                key=key,
                value=self._create_any(value),
            ).get_message()
            for key, value in pairs
        ]
        msg_obj = Message("Mysqlx.Datatypes.Object", fld=obj_flds)
        return Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)

    def _create_any(self, arg: Any) -> Optional[MessageType]:
        """Create a ``Mysqlx.Datatypes.Any`` message from a Python value.

        Conversions:

        * ``str``, ``bool`` and ``int`` -> scalar.
        * a 2-item ``tuple`` ``(key, value)`` -> object with one field.
        * a ``dict`` -> object with one field per item.
        * a non-empty ``list``/``tuple`` whose first item is a ``dict`` ->
          array with one object per ``dict``.
        * any other ``list`` of ``(key, value)`` pairs -> object.

        Args:
            arg (object): Arbitrary object.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message, or ``None`` when
                                     the type of `arg` is not handled.
        """
        if isinstance(arg, str):
            value = Message("Mysqlx.Datatypes.Scalar.String", value=arg)
            scalar = Message("Mysqlx.Datatypes.Scalar", type=8, v_string=value)
            return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(arg, bool):
            return Message(
                "Mysqlx.Datatypes.Any", type=1, scalar=build_bool_scalar(arg)
            )
        if isinstance(arg, int):
            scalar = (
                build_int_scalar(arg) if arg < 0 else build_unsigned_int_scalar(arg)
            )
            return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)
        if isinstance(arg, tuple) and len(arg) == 2:
            return self._create_object_any([arg])
        if isinstance(arg, dict):
            # Previously a dict was iterated by key and crashed on
            # `str.items()`; map it to an object like the rest of this class.
            return self._create_object_any(arg.items())
        if isinstance(arg, (list, tuple)) and arg and isinstance(arg[0], dict):
            # Array can only handle Any types, Mysqlx.Datatypes.Any.obj
            array_values = [
                self._create_object_any(items.items()).get_message()
                for items in arg
            ]
            msg = Message("Mysqlx.Datatypes.Array")
            msg["value"] = array_values
            return Message("Mysqlx.Datatypes.Any", type=3, array=msg)
        if isinstance(arg, list):
            return self._create_object_any(arg)
        return None

    def _get_binding_args(
        self, stmt: Union[FilterableStatement, SqlStatement], is_scalar: bool = True
    ) -> Union[List[None], List[Union[ProtobufMessageType, ProtobufMessageCextType]]]:
        """Returns the binding any/scalar.

        Args:
            stmt (Statement): A `Statement` based type object.
            is_scalar (bool): `True` to return scalar values.

        Raises:
            :class:`mysqlx.ProgrammingError`: If unable to find placeholder for
                                              parameter.

        Returns:
            list: A list of ``Any`` or ``Scalar`` objects.
        """

        def build_value(
            value: Any,
        ) -> Union[ProtobufMessageType, ProtobufMessageCextType]:
            if is_scalar:
                return build_scalar(value).get_message()
            any_msg = self._create_any(value)
            if any_msg is None:
                # _create_any() has no mapping for this type (e.g. float);
                # wrap the same scalar encoding used when `is_scalar` is true
                # instead of failing on `None.get_message()`.
                any_msg = Message(
                    "Mysqlx.Datatypes.Any", type=1, scalar=build_scalar(value)
                )
            return any_msg.get_message()

        bindings = stmt.get_bindings()
        binding_map = stmt.get_binding_map()

        # If binding_map is None it's a SqlStatement object
        if binding_map is None:
            return [build_value(value) for value in bindings]

        count = len(binding_map)
        args: List[Any] = count * [None]
        if count != len(bindings):
            raise ProgrammingError(
                "The number of bind parameters and placeholders do not match"
            )
        for name, value in bindings.items():  # type: ignore[union-attr]
            if name not in binding_map:
                raise ProgrammingError(
                    f"Unable to find placeholder for parameter: {name}"
                )
            pos = binding_map[name]
            args[pos] = build_value(value)
        return args

    def _process_frame(self, msg: MessageType, result: ResultBaseType) -> None:
        """Process a notice frame and record its effect on `result`.

        Frame type 1 is a warning, 2 a session variable change and 3 a session
        state change (generated ids, rows affected or generated insert id).

        Args:
            msg (mysqlx.protobuf.Message): A MySQL X Protobuf Message.
            result (Result): A `Result` based type object.
        """
        if msg["type"] == 1:
            warn_msg = Message.from_message("Mysqlx.Notice.Warning", msg["payload"])
            self._warnings.append(warn_msg.msg)
            logger.warning(
                "Protocol.process_frame Received Warning Notice code %s: %s",
                warn_msg.code,
                warn_msg.msg,
            )
            result.append_warning(warn_msg.level, warn_msg.code, warn_msg.msg)
        elif msg["type"] == 2:
            Message.from_message("Mysqlx.Notice.SessionVariableChanged", msg["payload"])
        elif msg["type"] == 3:
            sess_state_msg = Message.from_message(
                "Mysqlx.Notice.SessionStateChanged", msg["payload"]
            )
            if sess_state_msg["param"] == mysqlxpb_enum(
                "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_DOCUMENT_IDS"
            ):
                result.set_generated_ids(
                    [
                        get_item_or_attr(
                            get_item_or_attr(value, "v_octets"), "value"
                        ).decode()
                        for value in sess_state_msg["value"]
                    ]
                )
            else:  # Following results are unitary and not a list
                sess_state_value = sess_state_msg["value"].pop()
                if sess_state_msg["param"] == mysqlxpb_enum(
                    "Mysqlx.Notice.SessionStateChanged.Parameter.ROWS_AFFECTED"
                ):
                    result.set_rows_affected(
                        get_item_or_attr(sess_state_value, "v_unsigned_int")
                    )
                elif sess_state_msg["param"] == mysqlxpb_enum(
                    "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_INSERT_ID"
                ):
                    result.set_generated_insert_id(
                        get_item_or_attr(sess_state_value, "v_unsigned_int")
                    )

    def _read_message(self, result: ResultBaseType) -> Optional[MessageType]:
        """Read the next result-relevant message, updating `result` on the way.

        Notices are processed and skipped, FetchDone/FetchDoneMoreResultsets
        only update `result`, and StmtExecuteOk ends the result.

        Args:
            result (Result): A `Result` based type object.

        Raises:
            :class:`mysqlx.OperationalError`: If the server sends an error.
            RuntimeError: If reading fails; the result's warnings, if any, are
                          appended to the message.

        Returns:
            mysqlx.protobuf.Message: The message, or ``None`` on StmtExecuteOk.
        """
        while True:
            try:
                msg = self._reader.read_message()
            except RuntimeError as err:
                warnings = result.get_warnings()
                if warnings:
                    raise RuntimeError(f"{err} reason: {warnings!r}") from err
                # Nothing to add: propagate the original error unchanged.
                raise
            if msg.type == "Mysqlx.Error":
                raise OperationalError(msg["msg"], msg["code"])
            if msg.type == "Mysqlx.Notice.Frame":
                try:
                    self._process_frame(msg, result)
                except (AttributeError, KeyError):
                    continue
            elif msg.type == "Mysqlx.Sql.StmtExecuteOk":
                return None
            elif msg.type == "Mysqlx.Resultset.FetchDone":
                result.set_closed(True)
            elif msg.type == "Mysqlx.Resultset.FetchDoneMoreResultsets":
                result.set_has_more_results(True)
            elif msg.type == "Mysqlx.Resultset.Row":
                result.set_has_data(True)
                break
            else:
                break
        return msg

    def set_compression(self, algorithm: str) -> None:
        """Sets the compression algorithm to be used by the compression
        object, for uplink and downlink.

        Args:
            algorithm (str): Algorithm to be used in compression/decompression.

        .. versionadded:: 8.0.21
        """
        self._compression_algorithm = algorithm
        self._reader.set_compression(algorithm)
        self._writer.set_compression(algorithm)

    def get_capabilites(self) -> MessageType:
        """Get capabilities.

        Raises:
            :class:`mysqlx.OperationalError`: If the server sends an error.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        msg = Message("Mysqlx.Connection.CapabilitiesGet")
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_GET"),
            msg,
        )
        msg = self._reader.read_message()
        while msg.type == "Mysqlx.Notice.Frame":
            msg = self._reader.read_message()

        if msg.type == "Mysqlx.Error":
            raise OperationalError(msg["msg"], msg["code"])

        return msg

    def set_capabilities(self, **kwargs: Any) -> None:
        """Set capabilities.

        Dict values are sent as objects; other values go through
        :meth:`_create_any`. Nothing is sent when no keyword is given.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Raises:
            :class:`mysqlx.InterfaceError`: If the server rejects the request
                                            (except error 5002, which is
                                            ignored).

        Returns:
            None
        """
        if not kwargs:
            return None
        capabilities = Message("Mysqlx.Connection.Capabilities")
        for key, value in kwargs.items():
            capability = Message("Mysqlx.Connection.Capability")
            capability["name"] = key
            if isinstance(value, dict):
                capability["value"] = self._create_object_any(
                    value.items()
                ).get_message()
            else:
                capability["value"] = self._create_any(value)
            capabilities["capabilities"].extend([capability.get_message()])
        msg = Message("Mysqlx.Connection.CapabilitiesSet")
        msg["capabilities"] = capabilities
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_SET"),
            msg,
        )

        try:
            self.read_ok()
        except InterfaceError as err:
            # Skip capability "session_connect_attrs" error since
            # is only available on version >= 8.0.16
            if err.errno != 5002:
                raise
        return None

    def send_auth_start(
        self,
        method: str,
        auth_data: Optional[str] = None,
        initial_response: Optional[str] = None,
    ) -> None:
        """Send authenticate start.

        Args:
            method (str): Message method.
            auth_data (Optional[str]): Authentication data.
            initial_response (Optional[str]): Initial response.
        """
        msg = Message("Mysqlx.Session.AuthenticateStart")
        msg["mech_name"] = method
        if auth_data is not None:
            msg["auth_data"] = auth_data
        if initial_response is not None:
            msg["initial_response"] = initial_response
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_START"),
            msg,
        )

    def read_auth_continue(self) -> bytes:
        """Read authenticate continue, skipping any notices before it.

        Raises:
            :class:`InterfaceError`: If the message type is not
                                     `Mysqlx.Session.AuthenticateContinue`

        Returns:
            bytes: The authentication data.
        """
        msg = self._reader.read_message()
        while msg.type == "Mysqlx.Notice.Frame":
            msg = self._reader.read_message()
        if msg.type != "Mysqlx.Session.AuthenticateContinue":
            raise InterfaceError(
                "Unexpected message encountered during authentication handshake"
            )
        return msg["auth_data"]

    def send_auth_continue(self, auth_data: str) -> None:
        """Send authenticate continue.

        Args:
            auth_data (str): Authentication data.
        """
        msg = Message("Mysqlx.Session.AuthenticateContinue", auth_data=auth_data)
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_CONTINUE"),
            msg,
        )

    def read_auth_ok(self) -> None:
        """Read authenticate OK, ignoring any other message before it.

        Raises:
            :class:`mysqlx.InterfaceError`: If message type is `Mysqlx.Error`.
        """
        while True:
            msg = self._reader.read_message()
            if msg.type == "Mysqlx.Session.AuthenticateOk":
                break
            if msg.type == "Mysqlx.Error":
                raise InterfaceError(msg.msg)

    def send_prepare_prepare(
        self,
        msg_type: str,
        msg: MessageType,
        stmt: Union[
            FindStatement,
            DeleteStatement,
            ModifyStatement,
            ReadStatement,
            RemoveStatement,
            UpdateStatement,
        ],
    ) -> None:
        """
        Send prepare statement.

        When the statement has a limit, the limit is replaced by placeholders
        placed after the statement's own bindings, so the values can be sent
        with each execute.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.

        Raises:
            :class:`mysqlx.NotSupportedError`: If prepared statements are not
                                               supported.

        .. versionadded:: 8.0.16
        """
        if stmt.has_limit and msg.type != "Mysqlx.Crud.Insert":
            # Remove 'limit' from message by building a new one
            if msg.type == "Mysqlx.Crud.Find":
                _, msg = self.build_find(stmt)  # type: ignore[arg-type]
            elif msg.type == "Mysqlx.Crud.Update":
                _, msg = self.build_update(stmt)  # type: ignore[arg-type]
            elif msg.type == "Mysqlx.Crud.Delete":
                _, msg = self.build_delete(stmt)  # type: ignore[arg-type]
            else:
                raise ValueError(f"Invalid message type: {msg_type}")
            # Build 'limit_expr' message
            position = len(stmt.get_bindings())
            placeholder = mysqlxpb_enum("Mysqlx.Expr.Expr.Type.PLACEHOLDER")
            msg_limit_expr = Message("Mysqlx.Crud.LimitExpr")
            msg_limit_expr["row_count"] = Message(
                "Mysqlx.Expr.Expr", type=placeholder, position=position
            )
            if msg.type == "Mysqlx.Crud.Find":
                msg_limit_expr["offset"] = Message(
                    "Mysqlx.Expr.Expr", type=placeholder, position=position + 1
                )
            msg["limit_expr"] = msg_limit_expr

        oneof_type, oneof_op = CRUD_PREPARE_MAPPING[msg_type]
        msg_oneof = Message("Mysqlx.Prepare.Prepare.OneOfMessage")
        msg_oneof["type"] = mysqlxpb_enum(oneof_type)
        msg_oneof[oneof_op] = msg
        msg_prepare = Message("Mysqlx.Prepare.Prepare")
        msg_prepare["stmt_id"] = stmt.stmt_id
        msg_prepare["stmt"] = msg_oneof

        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_PREPARE"),
            msg_prepare,
        )

        try:
            self.read_ok()
        except InterfaceError as err:
            raise NotSupportedError from err

    def send_prepare_execute(
        self, msg_type: str, msg: MessageType, stmt: FilterableStatement
    ) -> None:
        """
        Send execute statement.

        Only the statement id and the bound arguments are sent. When the
        statement has a limit, its row count and offset are appended to fill
        the placeholders created by :meth:`send_prepare_prepare`.

        Args:
            msg_type (str): Message ID string (unused; kept for compatibility).
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message (unused;
                                           kept for compatibility).
            stmt (Statement): A `Statement` based type object.

        .. versionadded:: 8.0.16
        """
        # The previous implementation also built a OneOfMessage from
        # `msg_type`/`msg` here, but never sent it; that dead code is removed.
        msg_execute = Message("Mysqlx.Prepare.Execute")
        msg_execute["stmt_id"] = stmt.stmt_id

        args = self._get_binding_args(stmt, is_scalar=False)
        if args:
            msg_execute["args"].extend(args)

        if stmt.has_limit:
            msg_execute["args"].extend(
                [
                    self._create_any(stmt.get_limit_row_count()).get_message(),
                    self._create_any(stmt.get_limit_offset()).get_message(),
                ]
            )

        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_EXECUTE"),
            msg_execute,
        )

    def send_prepare_deallocate(self, stmt_id: int) -> None:
        """
        Send prepare deallocate statement and wait for the server's OK.

        Args:
            stmt_id (int): Statement ID.

        .. versionadded:: 8.0.16
        """
        msg_dealloc = Message("Mysqlx.Prepare.Deallocate")
        msg_dealloc["stmt_id"] = stmt_id
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_DEALLOCATE"),
            msg_dealloc,
        )
        self.read_ok()

    def send_msg_without_ps(
        self,
        msg_type: str,
        msg: MessageType,
        stmt: Union[FilterableStatement, SqlStatement],
    ) -> None:
        """
        Send a message without prepared statements support.

        The limit (with offset for Find messages) and the bound arguments are
        written directly into `msg`; SQL statements get ``Any`` arguments, CRUD
        messages get ``Scalar`` arguments.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.

        .. versionadded:: 8.0.16
        """
        if stmt.has_limit:
            msg_limit = Message("Mysqlx.Crud.Limit")
            msg_limit["row_count"] = stmt.get_limit_row_count()  # type: ignore[union-attr]
            if msg.type == "Mysqlx.Crud.Find":
                msg_limit["offset"] = stmt.get_limit_offset()  # type: ignore[union-attr]
            msg["limit"] = msg_limit
        is_scalar = msg_type != "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE"
        args = self._get_binding_args(stmt, is_scalar=is_scalar)
        if args:
            msg["args"].extend(args)
        self.send_msg(msg_type, msg)

    def send_msg(self, msg_type: str, msg: MessageType) -> None:
        """
        Send a message.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        self._writer.write_message(mysqlxpb_enum(msg_type), msg)

    def build_find(
        self, stmt: Union[FindStatement, ReadStatement]
    ) -> Tuple[str, MessageType]:
        """Build find/read message.

        Args:
            stmt (Statement): A :class:`mysqlx.ReadStatement` or
                              :class:`mysqlx.FindStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message("Mysqlx.Crud.Find", data_model=data_model, collection=collection)
        if stmt.has_projection:
            msg["projection"] = stmt.get_projection_expr()
        self._apply_filter(msg, stmt)

        if stmt.is_lock_exclusive():
            msg["locking"] = mysqlxpb_enum("Mysqlx.Crud.Find.RowLock.EXCLUSIVE_LOCK")
        elif stmt.is_lock_shared():
            msg["locking"] = mysqlxpb_enum("Mysqlx.Crud.Find.RowLock.SHARED_LOCK")

        if stmt.lock_contention.value > 0:
            msg["locking_options"] = stmt.lock_contention.value

        return "Mysqlx.ClientMessages.Type.CRUD_FIND", msg

    def build_update(
        self, stmt: Union[ModifyStatement, UpdateStatement]
    ) -> Tuple[str, MessageType]:
        """Build update message.

        Args:
            stmt (Statement): A :class:`mysqlx.ModifyStatement` or
                              :class:`mysqlx.UpdateStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message(
            "Mysqlx.Crud.Update", data_model=data_model, collection=collection
        )
        self._apply_filter(msg, stmt)
        for update_op in stmt.get_update_ops().values():
            operation = Message("Mysqlx.Crud.UpdateOperation")
            operation["operation"] = update_op.update_type
            operation["source"] = update_op.source
            if update_op.value is not None:
                operation["value"] = build_expr(update_op.value)
            msg["operation"].extend([operation.get_message()])

        return "Mysqlx.ClientMessages.Type.CRUD_UPDATE", msg

    def build_delete(
        self, stmt: Union[DeleteStatement, RemoveStatement]
    ) -> Tuple[str, MessageType]:
        """Build delete message.

        Args:
            stmt (Statement): A :class:`mysqlx.DeleteStatement` or
                              :class:`mysqlx.RemoveStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message(
            "Mysqlx.Crud.Delete", data_model=data_model, collection=collection
        )
        self._apply_filter(msg, stmt)
        return "Mysqlx.ClientMessages.Type.CRUD_DELETE", msg

    def build_execute_statement(
        self,
        namespace: str,
        stmt: Union[str, StatementType],
        fields: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, MessageType]:
        """Build execute statement.

        Args:
            namespace (str): The namespace.
            stmt (Statement): A `Statement` based type object.
            fields (Optional[dict]): The message fields; when given they are
                                     sent as a single object argument.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        msg = Message(
            "Mysqlx.Sql.StmtExecute",
            namespace=namespace,
            stmt=stmt,
            compact_metadata=False,
        )

        if fields:
            msg["args"] = [self._create_object_any(fields.items()).get_message()]

        return "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE", msg

    @staticmethod
    def build_insert(
        stmt: Union[AddStatement, InsertStatement],
    ) -> Tuple[str, MessageType]:
        """Build insert statement.

        Args:
            stmt (Statement): A :class:`mysqlx.AddStatement` or
                              :class:`mysqlx.InsertStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message(
            "Mysqlx.Crud.Insert", data_model=data_model, collection=collection
        )

        if hasattr(stmt, "_fields"):
            for field in stmt._fields:
                expr = ExprParser(
                    field, not stmt.is_doc_based()
                ).parse_table_insert_field()
                msg["projection"].extend([expr.get_message()])

        for value in stmt.get_values():
            row = Message("Mysqlx.Crud.Insert.TypedRow")
            # A list value is a full row; anything else is a one-field row.
            if isinstance(value, list):
                for val in value:
                    row["field"].extend([build_expr(val).get_message()])
            else:
                row["field"].extend([build_expr(value).get_message()])
            msg["row"].extend([row.get_message()])

        if hasattr(stmt, "is_upsert"):
            msg["upsert"] = stmt.is_upsert()

        return "Mysqlx.ClientMessages.Type.CRUD_INSERT", msg

    def close_result(self, result: ResultBaseType) -> None:
        """Close the result by consuming messages up to StmtExecuteOk.

        Args:
            result (Result): A `Result` based type object.

        Raises:
            :class:`mysqlx.OperationalError`: If the message read is not None,
                                              i.e. the result still has data.
        """
        msg = self._read_message(result)
        if msg is not None:
            raise OperationalError("Expected to close the result")

    def read_row(self, result: ResultBaseType) -> Optional[MessageType]:
        """Read row.

        A non-row message is pushed back to the reader and ``None`` returned.

        Args:
            result (Result): A `Result` based type object.

        Returns:
            mysqlx.protobuf.Message: The row message, or ``None``.
        """
        msg = self._read_message(result)
        if msg is None:
            return None
        if msg.type == "Mysqlx.Resultset.Row":
            return msg
        self._reader.push_message(msg)
        return None

    def get_column_metadata(self, result: ResultBaseType) -> List[ColumnType]:
        """Returns column metadata.

        Reads ColumnMetaData messages until the first row (which is pushed
        back to the reader) or the end of the result.

        Args:
            result (Result): A `Result` based type object.

        Raises:
            :class:`mysqlx.InterfaceError`: If unexpected message.
        """
        columns = []
        while True:
            msg = self._read_message(result)
            if msg is None:
                break
            if msg.type == "Mysqlx.Resultset.Row":
                self._reader.push_message(msg)
                break
            if msg.type != "Mysqlx.Resultset.ColumnMetaData":
                raise InterfaceError("Unexpected msg type")
            col = Column(
                msg["type"],
                msg["catalog"],
                msg["schema"],
                msg["table"],
                msg["original_table"],
                msg["name"],
                msg["original_name"],
                msg.get("length", 21),
                msg.get("collation", 0),
                msg.get("fractional_digits", 0),
                msg.get("flags", 16),
                msg.get("content_type"),
            )
            columns.append(col)
        return columns

    def read_ok(self) -> None:
        """Read OK, skipping any notices received before it.

        Raises:
            :class:`mysqlx.InterfaceError`: If unexpected message.
        """
        msg = self._reader.read_message()
        # Skip notices, as get_capabilites() and read_auth_continue() do.
        while msg.type == "Mysqlx.Notice.Frame":
            msg = self._reader.read_message()
        if msg.type == "Mysqlx.Error":
            raise InterfaceError(f"Mysqlx.Error: {msg['msg']}", errno=msg["code"])
        if msg.type != "Mysqlx.Ok":
            raise InterfaceError("Unexpected message encountered")

    def send_connection_close(self) -> None:
        """Send connection close."""
        msg = Message("Mysqlx.Connection.Close")
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CLOSE"), msg
        )

    def send_close(self) -> None:
        """Send close."""
        msg = Message("Mysqlx.Session.Close")
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_CLOSE"), msg
        )

    def send_expect_open(self) -> None:
        """Send an Expect.Open with an EXPECT_FIELD_EXIST condition ("6.1").

        send_reset() uses the reply to decide whether to ask the server to
        keep the session open.
        """
        cond_key = mysqlxpb_enum("Mysqlx.Expect.Open.Condition.Key.EXPECT_FIELD_EXIST")
        msg_oc = Message("Mysqlx.Expect.Open.Condition")
        msg_oc["condition_key"] = cond_key
        msg_oc["condition_value"] = "6.1"

        msg_eo = Message("Mysqlx.Expect.Open")
        msg_eo["cond"] = [msg_oc.get_message()]

        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.EXPECT_OPEN"), msg_eo
        )

    def send_reset(self, keep_open: Optional[bool] = None) -> bool:
        """Send reset session message.

        Args:
            keep_open (Optional[bool]): Whether to ask the server to keep the
                                        session open. When ``None``, an
                                        expectation is sent first and
                                        keep_open is requested only if the
                                        server accepts it.

        Returns:
            boolean: ``True`` if the server will keep the session open,
                     otherwise ``False``.
        """
        msg = Message("Mysqlx.Session.Reset")
        if keep_open is None:
            try:
                # Send expectation: keep connection open
                self.send_expect_open()
                self.read_ok()
                keep_open = True
            except InterfaceError:
                # Expectation is unknown by this version of the server
                keep_open = False
        if keep_open:
            msg["keep_open"] = True
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_RESET"), msg
        )
        self.read_ok()
        return bool(keep_open)