mysqlx-connector-python/lib/mysqlx/statement.py (662 lines of code) (raw):
# Copyright (c) 2016, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# mypy: disable-error-code="return-value"
"""Implementation of Statements."""
from __future__ import annotations
import copy
import json
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from .constants import LockContention
from .dbdoc import DbDoc
from .errors import NotSupportedError, ProgrammingError
from .expr import ExprParser
from .helpers import deprecated
from .protobuf import mysqlxpb_enum
from .result import DocResult, Result, RowResult, SqlResult
from .types import (
ConnectionType,
DatabaseTargetType,
MessageType,
ProtobufMessageCextType,
ProtobufMessageType,
SchemaType,
)
# Error message template for an invalid index name; the single "{}"
# placeholder is meant to be filled (e.g. via str.format) with the name.
# NOTE(review): not referenced within this chunk — confirm its callers.
ERR_INVALID_INDEX_NAME = 'The given index name "{}" is not valid'
class Expr:
    """Thin wrapper that carries an expression object unchanged.

    Args:
        expr (object): The expression to wrap; exposed as ``self.expr``.
    """

    def __init__(self, expr: Any) -> None:
        wrapped: Any = expr
        self.expr: Any = wrapped
def flexible_params(*values: Any) -> Union[List, Tuple]:
    """Normalize variadic arguments.

    Callers may pass several positional values or one list/tuple holding
    them; both spellings yield the same sequence.

    Args:
        *values: The positional values, or a single list/tuple of values.

    Returns:
        list or tuple: The sole list/tuple argument itself when exactly one
        list/tuple was given, otherwise the tuple of positional values.
    """
    single_sequence = len(values) == 1 and isinstance(values[0], (list, tuple))
    return values[0] if single_sequence else values
def is_quoted_identifier(identifier: str, sql_mode: str = "") -> bool:
    """Check if the given identifier is quoted.

    An identifier counts as quoted when it is at least two characters long
    and starts and ends with the same quote character. Backticks are always
    accepted. Double quotes are also accepted when ``sql_mode`` contains
    ``ANSI_QUOTES``.

    Args:
        identifier (string): Identifier to check.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        `True` if the identifier is enclosed in quotes, and `False` otherwise
        (including for empty and single-character identifiers).
    """
    # A lone quote character is not an enclosed identifier, and indexing an
    # empty string would raise IndexError.
    if len(identifier) < 2:
        return False
    quotes = ("`", '"') if "ANSI_QUOTES" in sql_mode else ("`",)
    return identifier[0] in quotes and identifier[-1] == identifier[0]
def quote_identifier(identifier: str, sql_mode: str = "") -> str:
    """Quote an identifier, escaping embedded quote characters by doubling.

    Backticks are used by default. When ``sql_mode`` contains
    ``ANSI_QUOTES``, double quotes are used instead and embedded double
    quotes are doubled. An empty identifier always yields two backticks.

    Args:
        identifier (string): Identifier to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        str: The quoted identifier.
    """
    if not len(identifier):
        return "``"
    mark = '"' if "ANSI_QUOTES" in sql_mode else "`"
    escaped = identifier.replace(mark, mark + mark)
    return f"{mark}{escaped}{mark}"
def quote_multipart_identifier(identifiers: Iterable[str], sql_mode: str = "") -> str:
    """Quote every part of a multi-part identifier and join them with dots.

    Args:
        identifiers (iterable): The identifier parts, in order.
        sql_mode (Optional[string]): SQL mode, forwarded to
                                     :func:`quote_identifier` for each part.

    Returns:
        str: The dot-separated, quoted identifier.
    """
    quoted_parts = [quote_identifier(part, sql_mode) for part in identifiers]
    return ".".join(quoted_parts)
def parse_table_name(
    default_schema: str, table_name: str, sql_mode: str = ""
) -> Tuple[str, str]:
    """Split a possibly schema-qualified table name into its parts.

    The name is split at the first dot that is not inside a quoted section,
    so a dot within a quoted identifier stays part of that identifier. Quote
    characters are stripped from both ends of each part. The quote character
    is the backtick, or the double quote when ``sql_mode`` contains
    ``ANSI_QUOTES``.

    Args:
        default_schema (str): Schema returned when ``table_name`` carries no
                              schema qualifier.
        table_name (str): The table name, optionally prefixed by a schema.
        sql_mode (Optional[str]): The SQL mode.

    Returns:
        Tuple[str, str]: The ``(schema, table)`` pair.
    """
    quote = '"' if "ANSI_QUOTES" in sql_mode else "`"
    in_quotes = False
    for pos, char in enumerate(table_name):
        if char == quote:
            # A doubled (escaped) quote toggles twice, leaving the state as is.
            in_quotes = not in_quotes
        elif char == "." and not in_quotes:
            return (
                table_name[:pos].strip(quote),
                table_name[pos + 1 :].strip(quote),
            )
    return default_schema, table_name.strip(quote)
class Statement:
    """Base class holding the state shared by every statement object.

    Tracks the target database object, whether the statement is document
    based, and the bookkeeping used for prepared statements: statement ID,
    execution counter, and the changed / prepared /
    deallocate-prepare-execute flags.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target: DatabaseTargetType, doc_based: bool = True) -> None:
        self._target: DatabaseTargetType = target
        self._doc_based: bool = doc_based
        # Without a target there is no object to obtain a connection from.
        connection: Optional[ConnectionType] = None
        if target:
            connection = target.get_connection()
        self._connection: Optional[ConnectionType] = connection
        self._stmt_id: Optional[int] = None
        self._exec_counter: int = 0
        self._changed: bool = True
        self._prepared: bool = False
        self._deallocate_prepare_execute: bool = False

    @property
    def target(self) -> DatabaseTargetType:
        """object: The target database object of this statement."""
        return self._target

    @property
    def schema(self) -> SchemaType:
        """:class:`mysqlx.Schema`: The schema the target belongs to."""
        return self._target.schema

    @property
    def stmt_id(self) -> int:
        """Returns this statement ID.

        Returns:
            int: The statement ID (``None`` until one is assigned).
        """
        return self._stmt_id

    @stmt_id.setter
    def stmt_id(self, value: int) -> None:
        self._stmt_id = value

    @property
    def exec_counter(self) -> int:
        """int: How many times this statement has been executed."""
        return self._exec_counter

    @property
    def changed(self) -> bool:
        """bool: `True` when the statement has pending changes."""
        return self._changed

    @changed.setter
    def changed(self, value: bool) -> None:
        self._changed = value

    @property
    def prepared(self) -> bool:
        """bool: `True` once the statement has been prepared."""
        return self._prepared

    @prepared.setter
    def prepared(self, value: bool) -> None:
        self._prepared = value

    @property
    def repeated(self) -> bool:
        """bool: `True` when the statement ran at least twice."""
        return self._exec_counter >= 2

    @property
    def deallocate_prepare_execute(self) -> bool:
        """bool: `True` to deallocate, re-prepare and execute the statement."""
        return self._deallocate_prepare_execute

    @deallocate_prepare_execute.setter
    def deallocate_prepare_execute(self, value: bool) -> None:
        self._deallocate_prepare_execute = value

    def is_doc_based(self) -> bool:
        """Tell whether the statement is document based.

        Returns:
            bool: `True` if it is document based.
        """
        return self._doc_based

    def increment_exec_counter(self) -> None:
        """Record one more execution of this statement."""
        self._exec_counter = self._exec_counter + 1

    def reset_exec_counter(self) -> None:
        """Set the execution counter back to zero."""
        self._exec_counter = 0

    def execute(self) -> Any:
        """Execute the statement.

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError
class FilterableStatement(Statement):
    """A statement to be used with filterable statements.

    Holds the optional clauses shared by the filterable statements in this
    module: search condition, sorting, grouping, having, projection,
    limit/offset and placeholder bindings. Each ``has_*`` flag records
    whether the matching clause has been set.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """

    def __init__(
        self,
        target: DatabaseTargetType,
        doc_based: bool = True,
        condition: Optional[str] = None,
    ) -> None:
        super().__init__(target=target, doc_based=doc_based)
        self._binding_map: Dict[str, Any] = {}
        self._bindings: Union[Dict[str, Any], List] = {}
        self._having: Optional[MessageType] = None
        self._grouping_str: str = ""
        self._grouping: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._limit_offset: int = 0
        self._limit_row_count: Optional[int] = None
        self._projection_str: str = ""
        self._projection_expr: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._sort_str: str = ""
        self._sort_expr: Optional[
            List[Union[ProtobufMessageType, ProtobufMessageCextType]]
        ] = None
        self._where_str: str = ""
        self._where_expr: Optional[MessageType] = None
        self.has_bindings: bool = False
        self.has_limit: bool = False
        self.has_group_by: bool = False
        self.has_having: bool = False
        self.has_projection: bool = False
        self.has_sort: bool = False
        self.has_where: bool = False
        if condition:
            self._set_where(condition)

    def _bind_single(self, obj: Union[DbDoc, Dict[str, Any], str]) -> None:
        """Bind every key/value pair of a single document-like object.

        Args:
            obj (:class:`mysqlx.DbDoc`, dict or str): DbDoc, dict or JSON
                                                      string object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the JSON string is invalid,
                                              does not decode to an object,
                                              or ``obj`` has another type.
        """
        if isinstance(obj, dict):
            self.bind(DbDoc(obj).as_str())
        elif isinstance(obj, DbDoc):
            self.bind(obj.as_str())
        elif isinstance(obj, str):
            try:
                res = json.loads(obj)
                if not isinstance(res, dict):
                    raise ValueError
            except ValueError as err:
                raise ProgrammingError("Invalid JSON string to bind") from err
            for key in res.keys():
                self.bind(key, res[key])
        else:
            raise ProgrammingError("Invalid JSON string or object to bind")

    def _sort(self, *clauses: str) -> FilterableStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        self.has_sort = True
        self._sort_str = ",".join(flexible_params(*clauses))
        self._sort_expr = ExprParser(
            self._sort_str, not self._doc_based
        ).parse_order_spec()
        self._changed = True
        return self

    def _set_where(self, condition: str) -> FilterableStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the condition cannot be parsed.
        """
        try:
            expr = ExprParser(condition, not self._doc_based)
            where_expr = expr.expr()
        except ValueError as err:
            raise ProgrammingError("Invalid condition") from err
        # Only record the clause once it parsed, so an invalid condition does
        # not leave the statement flagged as filtered with a stale expression.
        self.has_where = True
        self._where_str = condition
        self._where_expr = where_expr
        self._binding_map = expr.placeholder_name_to_position
        self._changed = True
        return self

    def _set_group_by(self, *fields: str) -> None:
        """Set group by.

        Args:
            *fields: List of fields.
        """
        fields = flexible_params(*fields)
        self.has_group_by = True
        self._grouping_str = ",".join(fields)
        self._grouping = ExprParser(
            self._grouping_str, not self._doc_based
        ).parse_expr_list()
        self._changed = True

    def _set_having(self, condition: str) -> None:
        """Set having.

        Args:
            condition (str): The condition.
        """
        self.has_having = True
        self._having = ExprParser(condition, not self._doc_based).expr()
        self._changed = True

    def _set_projection(self, *fields: str) -> FilterableStatement:
        """Set the projection.

        Args:
            *fields: List of fields.

        Returns:
            :class:`mysqlx.FilterableStatement`: Returns self.
        """
        fields = flexible_params(*fields)
        self.has_projection = True
        self._projection_str = ",".join(fields)
        self._projection_expr = ExprParser(
            self._projection_str, not self._doc_based
        ).parse_table_select_projection()
        self._changed = True
        return self

    def get_binding_map(self) -> Dict[str, Any]:
        """Returns the binding map dictionary.

        Returns:
            dict: The placeholder-name-to-position map of the condition.
        """
        return self._binding_map

    def get_bindings(self) -> Union[Dict[str, Any], List]:
        """Returns the bindings.

        Returns:
            dict: The bound values, keyed by placeholder name.
        """
        return self._bindings

    def get_grouping(self) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]:
        """Returns the grouping expression list.

        Returns:
            `list`: The grouping expression list.
        """
        return self._grouping

    def get_having(self) -> MessageType:
        """Returns the having expression.

        Returns:
            object: The having expression.
        """
        return self._having

    def get_limit_row_count(self) -> int:
        """Returns the limit row count.

        Returns:
            int: The limit row count (``None`` when no limit was set).
        """
        return self._limit_row_count

    def get_limit_offset(self) -> int:
        """Returns the limit offset.

        Returns:
            int: The limit offset.
        """
        return self._limit_offset

    def get_where_expr(self) -> MessageType:
        """Returns the where expression.

        Returns:
            object: The where expression.
        """
        return self._where_expr

    def get_projection_expr(
        self,
    ) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]:
        """Returns the projection expression.

        Returns:
            object: The projection expression.
        """
        return self._projection_expr

    def get_sort_expr(
        self,
    ) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]:
        """Returns the sort expression.

        Returns:
            object: The sort expression.
        """
        return self._sort_expr

    @deprecated("8.0.12")
    def where(self, condition: str) -> FilterableStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._set_where(condition)

    @deprecated("8.0.12")
    def sort(self, *clauses: str) -> FilterableStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._sort(*clauses)

    def limit(
        self, row_count: int, offset: Optional[int] = None
    ) -> FilterableStatement:
        """Sets the maximum number of items to be returned.

        Args:
            row_count (int): The maximum number of items.
            offset (Optional[int]): Deprecated. The number of items to skip;
                                    use :meth:`offset` instead.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``row_count`` is not a non-negative integer.

        .. versionchanged:: 8.0.12
           The usage of ``offset`` was deprecated.
        """
        if not isinstance(row_count, int) or row_count < 0:
            raise ValueError("The 'row_count' value must be a positive integer")
        if not self.has_limit:
            # First LIMIT on this statement: if it never ran, simply mark it
            # changed; otherwise flag it for deallocate + prepare + execute.
            self._changed = bool(self._exec_counter == 0)
            self._deallocate_prepare_execute = bool(not self._exec_counter == 0)

        self._limit_row_count = row_count
        self.has_limit = True
        if offset:
            self.offset(offset)
            warnings.warn(
                "'limit(row_count, offset)' is deprecated, please "
                "use 'offset(offset)' to set the number of items to "
                "skip",
                category=DeprecationWarning,
            )
        return self

    def offset(self, offset: int) -> FilterableStatement:
        """Sets the number of items to skip.

        Args:
            offset (int): The number of items to skip.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``offset`` is not a non-negative integer.

        .. versionadded:: 8.0.12
        """
        if not isinstance(offset, int) or offset < 0:
            raise ValueError("The 'offset' value must be a positive integer")
        self._limit_offset = offset
        return self

    def bind(self, *args: Any) -> FilterableStatement:
        """Binds value(s) to a specific placeholder(s).

        Args:
            *args: Either the name of the placeholder and the value to bind,
                   or a single :class:`mysqlx.DbDoc`, dict or JSON string
                   whose key/value pairs are all bound.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the number of arguments is invalid.
        """
        self.has_bindings = True
        count = len(args)
        if count == 1:
            self._bind_single(args[0])
        elif count == 2:
            self._bindings[args[0]] = args[1]
        else:
            raise ProgrammingError("Invalid number of arguments to bind")
        return self

    def execute(self) -> Any:
        """Execute the statement.

        Raises:
            NotImplementedError: This method must be implemented.
        """
        raise NotImplementedError
class SqlStatement(Statement):
    """A statement for SQL execution.

    Args:
        connection (mysqlx.connection.Connection): Connection object.
        sql (string): The sql statement to be executed.
    """

    def __init__(self, connection: ConnectionType, sql: str) -> None:
        # Raw SQL has no target database object, so the connection is passed
        # in explicitly rather than taken from a target.
        super().__init__(target=None, doc_based=False)
        self._connection: ConnectionType = connection
        self._sql: str = sql
        self._binding_map: Optional[Dict[str, Any]] = None
        self._bindings: Union[List, Tuple] = []
        self.has_bindings: bool = False
        self.has_limit: bool = False

    @property
    def sql(self) -> str:
        """string: The SQL text statement."""
        return self._sql

    def get_binding_map(self) -> Optional[Dict[str, Any]]:
        """Returns the binding map dictionary.

        Returns:
            dict: The binding map dictionary (always ``None`` here).
        """
        return self._binding_map

    def get_bindings(self) -> Union[Tuple, List]:
        """Returns the bindings list.

        Returns:
            `list` or `tuple`: The bound values.
        """
        return self._bindings

    def bind(self, *args: Any) -> SqlStatement:
        """Binds value(s) to the statement placeholders.

        The values may be given as positional arguments or as a single
        list/tuple. Each call replaces any previously bound values.

        Args:
            *args: The value(s) to bind.

        Returns:
            mysqlx.SqlStatement: SqlStatement object.

        Raises:
            ProgrammingError: If no value is given.
        """
        if not args:
            raise ProgrammingError("Invalid number of arguments to bind")
        self.has_bindings = True
        # flexible_params() always returns a list or tuple: the single
        # list/tuple argument itself or the tuple of positional arguments.
        self._bindings = flexible_params(*args)
        return self

    def execute(self) -> SqlResult:
        """Execute the statement.

        Returns:
            mysqlx.SqlResult: SqlResult object.
        """
        return self._connection.send_sql(self)
class WriteStatement(Statement):
    """Base class for statements that write values (insert/add).

    Subclasses append the values to write to ``self._values``.

    Args:
        target (object): The target database object.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target: DatabaseTargetType, doc_based: bool) -> None:
        super().__init__(target=target, doc_based=doc_based)
        self._values: List[
            Union[
                int,
                str,
                DbDoc,
                Dict[str, Any],
                List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]],
            ]
        ] = []

    def get_values(
        self,
    ) -> List[
        Union[
            int,
            str,
            DbDoc,
            Dict[str, Any],
            List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]],
        ]
    ]:
        """Give access to the collected values.

        Returns:
            `list`: The values queued for writing (the internal list itself).
        """
        return self._values

    def execute(self) -> Any:
        """Execute the statement.

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError
class AddStatement(WriteStatement):
    """A statement for document addition on a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
    """

    def __init__(self, collection: DatabaseTargetType) -> None:
        super().__init__(target=collection, doc_based=True)
        self._upsert: bool = False
        self.ids: List = []

    def is_upsert(self) -> bool:
        """Tell whether the upsert flag is set.

        Returns:
            bool: `True` if it's an upsert.
        """
        return self._upsert

    def upsert(self, value: bool = True) -> AddStatement:
        """Sets the upsert flag to the value provided.

        Setting this flag allows updating the matched rows/documents with
        the provided value.

        Args:
            value (optional[bool]): Set or unset the upsert flag.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        self._upsert = value
        return self

    def add(self, *values: DbDoc) -> AddStatement:
        """Queue documents to be added into the collection.

        Values that are not already :class:`mysqlx.DbDoc` instances are
        wrapped in one.

        Args:
            *values: The documents to be added, or a single list/tuple of them.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        for candidate in flexible_params(*values):
            document = candidate if isinstance(candidate, DbDoc) else DbDoc(candidate)
            self._values.append(document)
        return self

    def execute(self) -> Result:
        """Execute the statement.

        When no documents were queued, nothing is sent and an empty
        :class:`mysqlx.Result` is returned.

        Returns:
            mysqlx.Result: Result object.
        """
        if not self._values:
            return Result()
        return self._connection.send_insert(self)
class UpdateSpec:
    """Update specification class implementation.

    Args:
        update_type (int): The update type.
        source (str): The source.
        value (Optional[str]): The value.

    Raises:
        ProgrammingError: If `source` is invalid.
    """

    def __init__(self, update_type: int, source: str, value: Any = None) -> None:
        if update_type == mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET"):
            self._table_set(source, value)
        else:
            self.update_type: int = update_type
            try:
                self.source: Any = ExprParser(source, False).document_field().identifier
            except ValueError as err:
                raise ProgrammingError(f"{err}") from err
            self.value: Any = value

    def _table_set(self, source: str, value: Any) -> None:
        """Configure a table (relational) SET operation.

        Args:
            source (str): The column to update.
            value (object): The value.

        Raises:
            ProgrammingError: If `source` cannot be parsed as a table field,
                              matching the error raised for document paths.
        """
        self.update_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        try:
            self.source = ExprParser(source, True).parse_table_update_field()
        except ValueError as err:
            raise ProgrammingError(f"{err}") from err
        self.value = value
class ModifyStatement(FilterableStatement):
    """A statement for document update operations on a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be modified.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter is now mandatory.
    """

    def __init__(self, collection: DatabaseTargetType, condition: str) -> None:
        super().__init__(target=collection, condition=condition)
        self._update_ops: Dict[str, Any] = {}

    def sort(self, *clauses: str) -> ModifyStatement:
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._sort(*clauses)

    def get_update_ops(self) -> Dict[str, Any]:
        """Returns the update operations.

        Returns:
            dict: :class:`UpdateSpec` objects keyed by document path (the key
                  ``"patch"`` holds a merge-patch operation).
        """
        return self._update_ops

    def set(self, doc_path: str, value: Any) -> ModifyStatement:
        """Sets or updates attributes on documents in a collection.

        Args:
            doc_path (string): The document path of the item to be set.
            value (string): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._update_ops[doc_path] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_SET"),
            doc_path,
            value,
        )
        self._changed = True
        return self

    @deprecated("8.0.12")
    def change(self, doc_path: str, value: Any) -> ModifyStatement:
        """Add an update to the statement setting the field, if it exists at
        the document path, to the given value.

        Args:
            doc_path (string): The document path of the item to be set.
            value (object): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        .. deprecated:: 8.0.12
        """
        self._update_ops[doc_path] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REPLACE"),
            doc_path,
            value,
        )
        self._changed = True
        return self

    def unset(self, *doc_paths: str) -> ModifyStatement:
        """Removes attributes from documents in a collection.

        Args:
            doc_paths (list): The document paths of the attributes to be
                              removed, as positional arguments or a single
                              list/tuple.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        for item in flexible_params(*doc_paths):
            self._update_ops[item] = UpdateSpec(
                mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REMOVE"),
                item,
            )
        self._changed = True
        return self

    def array_insert(self, field: str, value: Any) -> ModifyStatement:
        """Insert a value into the specified array in documents of a
        collection.

        Args:
            field (string): A document path that identifies the array attribute
                            and position where the value will be inserted.
            value (object): The value to be inserted.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._update_ops[field] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_INSERT"),
            field,
            value,
        )
        self._changed = True
        return self

    def array_append(self, doc_path: str, value: Any) -> ModifyStatement:
        """Appends a value to an array attribute in documents of a
        collection.

        Args:
            doc_path (string): A document path that identifies the array
                               attribute the value will be appended to.
            value (object): The value to be appended.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._update_ops[doc_path] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_APPEND"),
            doc_path,
            value,
        )
        self._changed = True
        return self

    def patch(self, doc: Union[Dict, DbDoc, ExprParser, str]) -> ModifyStatement:
        """Takes a :class:`mysqlx.DbDoc`, string JSON format or a dict with the
        changes and applies it on all matching documents.

        Args:
            doc (object): A generic document (DbDoc), string in JSON format,
                          dict or :class:`ExprParser`, with the changes to
                          apply to the matching documents. ``None`` is
                          treated as an empty string.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        Raises:
            ProgrammingError: If ``doc`` has an unsupported type.
        """
        if doc is None:
            doc = ""
        if not isinstance(doc, (ExprParser, dict, DbDoc, str)):
            raise ProgrammingError(
                "Invalid data for update operation on document collection table"
            )
        # The merge patch applies to the whole document, hence the "$" root.
        self._update_ops["patch"] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.MERGE_PATCH"),
            "$",
            doc.expr() if isinstance(doc, ExprParser) else doc,
        )
        self._changed = True
        return self

    def execute(self) -> Result:
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        if not self.has_where:
            raise ProgrammingError("No condition was found for modify")
        return self._connection.send_update(self)
class ReadStatement(FilterableStatement):
    """Provide base functionality for Read operations

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """

    def __init__(
        self,
        target: DatabaseTargetType,
        doc_based: bool = True,
        condition: Optional[str] = None,
    ) -> None:
        super().__init__(target, doc_based, condition)
        self._lock_exclusive: bool = False
        self._lock_shared: bool = False
        self._lock_contention: LockContention = LockContention.DEFAULT

    @property
    def lock_contention(self) -> LockContention:
        """:class:`mysqlx.LockContention`: The lock contention value."""
        return self._lock_contention

    def _set_lock_contention(self, lock_contention: LockContention) -> None:
        """Set the lock contention.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Raises:
            ProgrammingError: If is an invalid lock contention value.
        """
        try:
            # Check if is a valid lock contention value. Objects without a
            # ``value`` attribute (e.g. plain strings) raise AttributeError,
            # which must also surface as the documented ProgrammingError.
            _ = LockContention(lock_contention.value)
        except (AttributeError, ValueError) as err:
            raise ProgrammingError(
                "Invalid lock contention mode. Use 'NOWAIT' or 'SKIP_LOCKED'"
            ) from err
        self._lock_contention = lock_contention

    def is_lock_exclusive(self) -> bool:
        """Returns `True` if is `EXCLUSIVE LOCK`.

        Returns:
            bool: `True` if is `EXCLUSIVE LOCK`.
        """
        return self._lock_exclusive

    def is_lock_shared(self) -> bool:
        """Returns `True` if is `SHARED LOCK`.

        Returns:
            bool: `True` if is `SHARED LOCK`.
        """
        return self._lock_shared

    def lock_shared(
        self, lock_contention: LockContention = LockContention.DEFAULT
    ) -> ReadStatement:
        """Execute a read operation with `SHARED LOCK`. Only one lock can be
        active at a time.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.

        Raises:
            ProgrammingError: If ``lock_contention`` is invalid.
        """
        self._lock_exclusive = False
        self._lock_shared = True
        self._set_lock_contention(lock_contention)
        return self

    def lock_exclusive(
        self, lock_contention: LockContention = LockContention.DEFAULT
    ) -> ReadStatement:
        """Execute a read operation with `EXCLUSIVE LOCK`. Only one lock can be
        active at a time.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.

        Raises:
            ProgrammingError: If ``lock_contention`` is invalid.
        """
        self._lock_exclusive = True
        self._lock_shared = False
        self._set_lock_contention(lock_contention)
        return self

    def group_by(self, *fields: str) -> ReadStatement:
        """Sets a grouping criteria for the resultset.

        Args:
            *fields: The string expressions identifying the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_group_by(*fields)
        return self

    def having(self, condition: str) -> ReadStatement:
        """Sets a condition for records to be considered in aggregate function
        operations.

        Args:
            condition (string): A condition on the aggregate functions used on
                                the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_having(condition)
        return self

    def execute(self) -> Union[DocResult, RowResult]:
        """Execute the statement.

        Returns:
            mysqlx.DocResult or mysqlx.RowResult: The result object.
        """
        return self._connection.send_find(self)
class FindStatement(ReadStatement):
    """Statement that selects documents from a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (Optional[str]): An optional expression to identify the
                                   documents to be retrieved. If not specified
                                   all the documents will be included on the
                                   result unless a limit is set.
    """

    def __init__(
        self, collection: DatabaseTargetType, condition: Optional[str] = None
    ) -> None:
        super().__init__(target=collection, doc_based=True, condition=condition)

    def fields(self, *fields: str) -> FindStatement:
        """Restrict which document fields are returned.

        Args:
            *fields: The string expressions identifying the fields to be
                     extracted.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        statement = self._set_projection(*fields)
        return statement

    def sort(self, *clauses: str) -> FindStatement:
        """Order the matched documents.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        statement = self._sort(*clauses)
        return statement
class SelectStatement(ReadStatement):
    """A statement for record retrieval operations on a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be retrieved.
    """

    def __init__(self, table: DatabaseTargetType, *fields: str) -> None:
        super().__init__(table, False)
        # Raw HAVING condition text for get_sql(); the parent class only keeps
        # the parsed expression, whose str() is not SQL.
        self._having_str: str = ""
        self._set_projection(*fields)

    def _set_having(self, condition: str) -> None:
        """Set having, also remembering the raw condition text.

        Args:
            condition (str): The condition.
        """
        super()._set_having(condition)
        self._having_str = condition

    def where(self, condition: str) -> SelectStatement:
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        return self._set_where(condition)

    def order_by(self, *clauses: str) -> SelectStatement:
        """Sets the order by criteria.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        return self._sort(*clauses)

    def get_sql(self) -> str:
        """Returns the generated SQL.

        The SQL is assembled from the raw clause strings given to this
        statement (projection, where, group by, having, order by, limit).

        Returns:
            str: The generated SQL.
        """
        where = f" WHERE {self._where_str}" if self.has_where else ""
        group_by = f" GROUP BY {self._grouping_str}" if self.has_group_by else ""
        having = f" HAVING {self._having_str}" if self.has_having else ""
        order_by = f" ORDER BY {self._sort_str}" if self.has_sort else ""
        limit = (
            f" LIMIT {self._limit_row_count} OFFSET {self._limit_offset}"
            if self.has_limit
            else ""
        )
        stmt = (
            f"SELECT {self._projection_str or '*'} "
            f"FROM {self.schema.name}.{self.target.name}"
            f"{where}{group_by}{having}{order_by}{limit}"
        )
        return stmt
class InsertStatement(WriteStatement):
    """Statement that inserts rows into a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The column names to insert into, or a single list/tuple.
    """

    def __init__(self, table: DatabaseTargetType, *fields: Any) -> None:
        super().__init__(target=table, doc_based=False)
        self._fields: Union[List, Tuple] = flexible_params(*fields)

    def values(self, *values: Any) -> InsertStatement:
        """Queue one row of values to be inserted.

        Args:
            *values: The column values of the row, or a single list/tuple.

        Returns:
            mysqlx.InsertStatement: InsertStatement object.
        """
        row = list(flexible_params(*values))
        self._values.append(row)
        return self

    def execute(self) -> Result:
        """Send the insert to the server.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.send_insert(self)
class UpdateStatement(FilterableStatement):
    """Statement that updates records of a Table.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``fields`` parameters were removed.
    """

    def __init__(self, table: DatabaseTargetType) -> None:
        super().__init__(target=table, doc_based=False)
        self._update_ops: Dict[str, Any] = {}

    def where(self, condition: str) -> UpdateStatement:
        """Restrict the update to records matching a condition.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        statement = self._set_where(condition)
        return statement

    def order_by(self, *clauses: str) -> UpdateStatement:
        """Order the records before they are updated.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def get_update_ops(self) -> Dict[str, Any]:
        """Give access to the registered update operations.

        Returns:
            dict: :class:`UpdateSpec` objects keyed by column name.
        """
        return self._update_ops

    def set(self, field: str, value: Any) -> UpdateStatement:
        """Register a new value for a column.

        Args:
            field (string): The column name to be updated.
            value (object): The value to be set on the specified column.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        set_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        self._update_ops[field] = UpdateSpec(set_type, field, value)
        self._changed = True
        return self

    def execute(self) -> Result:
        """Send the update to the server.

        Returns:
            mysqlx.Result: Result object

        Raises:
            ProgrammingError: If condition was not set.
        """
        if self.has_where:
            return self._connection.send_update(self)
        raise ProgrammingError("No condition was found for update")
class RemoveStatement(FilterableStatement):
    """Statement that removes documents from a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be removed.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was added.
    """

    def __init__(self, collection: DatabaseTargetType, condition: str) -> None:
        super().__init__(target=collection, condition=condition)

    def sort(self, *clauses: str) -> RemoveStatement:
        """Define the sorting applied to the matched documents.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.RemoveStatement: RemoveStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def execute(self) -> Result:
        """Send the removal through the session connection.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        if self.has_where:
            return self._connection.send_delete(self)
        # Refuse to remove documents without a filter condition.
        raise ProgrammingError("No condition was found for remove")
class DeleteStatement(FilterableStatement):
    """Statement that deletes records from a Table.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was removed.
    """

    def __init__(self, table: DatabaseTargetType) -> None:
        super().__init__(target=table, doc_based=False)

    def where(self, condition: str) -> DeleteStatement:
        """Define the search condition selecting the records to delete.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        statement = self._set_where(condition)
        return statement

    def order_by(self, *clauses: str) -> DeleteStatement:
        """Define the ordering applied to the matched records.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def execute(self) -> Result:
        """Send the deletion through the session connection.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        if self.has_where:
            return self._connection.send_delete(self)
        # Refuse to delete records without a filter condition.
        raise ProgrammingError("No condition was found for delete")
class CreateCollectionIndexStatement(Statement):
    """A statement that creates an index on a collection.

    Args:
        collection (mysqlx.Collection): Collection.
        index_name (string): Index name.
        index_desc (dict): A dictionary containing the field members that
                           constrain the index to be created. It must have
                           the form as shown in the following::

                               {"fields": [{"field": member_path,
                                            "type": member_type,
                                            "required": member_required,
                                            "array": array,
                                            "collation": collation,
                                            "options": options,
                                            "srid": srid},
                                            # {... more members,
                                            #      repeated as many times
                                            #      as needed}
                                            ],
                                "type": type}
    """

    def __init__(
        self,
        collection: DatabaseTargetType,
        index_name: str,
        index_desc: Dict[str, Any],
    ) -> None:
        super().__init__(target=collection)
        # Deep copy so later changes to the caller's dict do not leak in.
        self._index_desc: Dict[str, Any] = copy.deepcopy(index_desc)
        self._index_name: str = index_name
        self._fields_desc: List[Dict[str, Any]] = self._index_desc.pop("fields", [])

    def _validate_index_name(self) -> None:
        """Raise ProgrammingError unless the index name parses as an identifier."""
        if self._index_name is None:
            raise ProgrammingError(ERR_INVALID_INDEX_NAME.format(self._index_name))
        try:
            parsed_ident = ExprParser(self._index_name).expr().get_message()
            # The message is type dict when the Protobuf cext is used
            if isinstance(parsed_ident, dict):
                parsed_type = parsed_ident["type"]
            else:
                parsed_type = parsed_ident.type
            if parsed_type != mysqlxpb_enum("Mysqlx.Expr.Expr.Type.IDENT"):
                raise ProgrammingError(
                    ERR_INVALID_INDEX_NAME.format(self._index_name)
                )
        except (ValueError, AttributeError) as err:
            raise ProgrammingError(
                ERR_INVALID_INDEX_NAME.format(self._index_name)
            ) from err

    @staticmethod
    def _build_constraint(
        field_desc: Dict[str, Any], index_type: str
    ) -> Dict[str, Any]:
        """Pop the known members of one field description into a constraint.

        ``field_desc`` is consumed in place; whatever is left afterwards is an
        unrecognized member. A missing 'field' or 'type' raises KeyError.
        """
        constraint: Dict[str, Any] = {}
        constraint["member"] = field_desc.pop("field")
        constraint["type"] = field_desc.pop("type")
        constraint["required"] = field_desc.pop("required", False)
        constraint["array"] = field_desc.pop("array", False)
        if not isinstance(constraint["required"], bool):
            raise TypeError("Field member 'required' must be Boolean")
        if not isinstance(constraint["array"], bool):
            raise TypeError("Field member 'array' must be Boolean")
        if index_type.upper() == "SPATIAL" and not constraint["required"]:
            raise ProgrammingError(
                "Field member 'required' must be set to 'True' when "
                "index type is set to 'SPATIAL'"
            )
        # Compare case-insensitively, like every other type check here.
        if (
            index_type.upper() == "INDEX"
            and constraint["type"].upper() == "GEOJSON"
        ):
            raise ProgrammingError(
                "Index 'type' must be set to 'SPATIAL' when field "
                "type is set to 'GEOJSON'"
            )
        if "collation" in field_desc:
            if not constraint["type"].upper().startswith("TEXT"):
                raise ProgrammingError(
                    "The 'collation' member can only be used when "
                    "field type is set to 'TEXT', not "
                    f"'{constraint['type'].upper()}'"
                )
            constraint["collation"] = field_desc.pop("collation")
        # "options" and "srid" fields in IndexField can be
        # present only if "type" is set to "GEOJSON"
        if "options" in field_desc:
            if constraint["type"].upper() != "GEOJSON":
                raise ProgrammingError(
                    "The 'options' member can only be used when "
                    "index type is set to 'GEOJSON'"
                )
            constraint["options"] = field_desc.pop("options")
        if "srid" in field_desc:
            if constraint["type"].upper() != "GEOJSON":
                raise ProgrammingError(
                    "The 'srid' member can only be used when index "
                    "type is set to 'GEOJSON'"
                )
            constraint["srid"] = field_desc.pop("srid")
        return constraint

    def execute(self) -> Result:
        """Execute the statement.

        The index description is validated and translated into the argument
        dict of the ``create_collection_index`` admin command. Validation
        works on copies of the stored descriptions, so executing the
        statement more than once builds the same arguments each time.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If the index name is not a valid identifier, or
                              the description is missing required members,
                              contains unidentified members or combines
                              incompatible options.
            NotSupportedError: If a unique index is requested.
            TypeError: If a field's 'required' or 'array' member is not a
                       bool.
        """
        self._validate_index_name()

        # Members are consumed with ``pop`` below; copy so the stored
        # description stays intact for later executions.
        index_desc = copy.copy(self._index_desc)
        fields_desc = self._fields_desc

        # Validate members that constrain the index
        if not fields_desc:
            raise ProgrammingError(
                "Required member 'fields' not found in the given index "
                f"description: {index_desc}"
            )
        if not isinstance(fields_desc, list):
            raise ProgrammingError("Required member 'fields' must contain a list")

        args: Dict[str, Any] = {}
        args["name"] = self._index_name
        args["collection"] = self._target.name
        args["schema"] = self._target.schema.name
        args["type"] = index_desc.pop("type", "INDEX")
        args["unique"] = index_desc.pop("unique", False)
        # Currently unique indexes are not supported:
        if args["unique"]:
            raise NotSupportedError("Unique indexes are not supported.")
        args["constraint"] = []

        if index_desc:
            raise ProgrammingError(f"Unidentified fields: {index_desc}")

        # Build every constraint first. Unidentified inner members are
        # reported only after all fields have been processed.
        leftovers: List[Any] = []
        for field_desc in fields_desc:
            remaining = copy.copy(field_desc)
            try:
                constraint = self._build_constraint(remaining, args["type"])
            except KeyError as err:
                # Report the field description as the caller supplied it.
                raise ProgrammingError(
                    f"Required inner member {err} not found in constraint: "
                    f"{field_desc}"
                ) from err
            args["constraint"].append(constraint)
            leftovers.append(remaining)

        for remaining in leftovers:
            if remaining:
                raise ProgrammingError(f"Unidentified inner fields: {remaining}")

        return self._connection.execute_nonquery(
            "mysqlx", "create_collection_index", True, args
        )