core/lib/sqlparse/models.py
"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import hashlib
import logging
import re
from typing import List, NamedTuple, Optional, Set, Union
log = logging.getLogger(__name__)
def escape(keyword):
"""
    Escape backticks in the keyword when generating actual SQL.
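    Illustrative doctest:
    >>> escape("weird`name")
    'weird``name'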
"""
return keyword.replace("`", "``")
def is_equal(left, right):
"""
    If both left and right are None, then they are equal because neither has
    been initialized yet.
    If only one of them is None, then they are not equal.
    If both of them are not None, then it's possible they are equal, and we'll
    return True here and do some more comparison later.
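    Illustrative doctest:
    >>> is_equal(None, None)
    True
    >>> is_equal(1, None)
    False
    >>> is_equal(1, 1)
    True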
"""
if left is not None and right is not None:
# Neither of them is None
if left != right:
return False
else:
return True
elif left is None and right is not None:
# Only left is None
return False
elif left is not None and right is None:
# Only right is None
return False
else:
# Both of them are None
return True
class IndexColumn(object):
"""
    A column definition inside an index section.
    This is different from a table column definition, because only `name`,
    `length` and `order` are required for an index column definition.
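    Illustrative doctest (hand-built column; normally the parser populates
    these attributes):
    >>> col = IndexColumn()
    >>> col.name = "user_id"
    >>> col.length = 10
    >>> col.to_sql()
    '`user_id`(10)'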
"""
def __init__(self):
self.name = None
self.length = None
self.order = "ASC"
def __str__(self):
str_repr = ""
if self.length is not None:
str_repr = "{}({})".format(self.name, self.length)
else:
str_repr = "{}".format(self.name)
if self.order != "ASC":
str_repr += " DESC"
return str_repr
def __eq__(self, other):
if self.name != other.name:
return False
return self.length == other.length and self.order == other.order
def __ne__(self, other):
return not self == other
def to_sql(self):
sql_str = ""
if self.length is not None:
sql_str = "`{}`({})".format(escape(self.name), self.length)
else:
sql_str = "`{}`".format(escape(self.name))
if self.order != "ASC":
sql_str += " DESC"
return sql_str
class DocStoreIndexColumn(object):
"""
    A column definition inside an index section for DocStore.
    A DocStore index column has more attributes than the normal one.
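    Illustrative doctest (the document path and key type are made-up values):
    >>> dc = DocStoreIndexColumn()
    >>> dc.document_path = "data.id"
    >>> dc.key_type = "INT"
    >>> str(dc)
    'data.id AS INT'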
"""
def __init__(self):
self.document_path = None
self.key_type = None
self.length = None
def __str__(self):
if self.length is not None:
return "{} AS {}({})".format(self.document_path, self.key_type, self.length)
else:
return "{} AS {}".format(self.document_path, self.key_type)
def __eq__(self, other):
for attr in ("document_path", "key_type", "length"):
if not is_equal(getattr(self, attr), getattr(other, attr)):
return False
return True
def __ne__(self, other):
return not self == other
def to_sql(self):
if self.length is not None:
return "{} AS {}({})".format(self.document_path, self.key_type, self.length)
else:
return "{} AS {}".format(self.document_path, self.key_type)
class TableIndex(object):
"""
    An index definition. This can be defined either directly after a single
    column definition or after all column definitions.
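    Illustrative doctest (a simple secondary key over one column):
    >>> idx = TableIndex(name="idx_a")
    >>> col = IndexColumn()
    >>> col.name = "a"
    >>> idx.column_list.append(col)
    >>> idx.to_sql()
    'KEY `idx_a` (`a`)'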
"""
def __init__(self, name=None, is_unique=False):
self.name = name
self.key_block_size = None
self.comment = None
self.is_unique = is_unique
self.key_type = None
self.using = None
self.column_list = []
def __str__(self):
idx_str = []
idx_str.append("NAME: {}".format(self.name))
idx_str.append("IS UNIQUE: {}".format(self.is_unique))
idx_str.append("TYPE: {}".format(self.key_type))
col_list_str = []
for col_str in self.column_list:
col_list_str.append(str(col_str))
idx_str.append("KEY LIST: {}".format(",".join(col_list_str)))
if self.using:
idx_str.append("USING: {}".format(self.using))
idx_str.append("KEY_BLOCK_SIZE: {}".format(self.key_block_size))
idx_str.append("COMMENT: {}".format(self.comment))
return "/ ".join(idx_str)
def __eq__(self, other):
for attr in (
"name",
"key_block_size",
"comment",
"is_unique",
"key_type",
"using",
):
if not is_equal(getattr(self, attr), getattr(other, attr)):
return False
return self.column_list == other.column_list
def __ne__(self, other):
return not self == other
def to_sql(self):
segments = []
if self.name is not None:
if self.name == "PRIMARY":
segments.append("PRIMARY KEY")
else:
if self.is_unique:
segments.append("UNIQUE KEY `{}`".format(escape(self.name)))
elif self.key_type is not None:
segments.append(
"{} KEY `{}`".format(self.key_type, escape(self.name))
)
else:
segments.append("KEY `{}`".format(escape(self.name)))
else:
segments.append("KEY")
segments.append(
"({})".format(", ".join([col.to_sql() for col in self.column_list]))
)
if self.using is not None:
segments.append("USING {}".format(self.using))
if self.key_block_size is not None:
segments.append("KEY_BLOCK_SIZE={}".format(self.key_block_size))
if self.comment is not None:
segments.append("COMMENT {}".format(self.comment))
return " ".join(segments)
class Column(object):
"""
    Representing a column definition in a table
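    Illustrative doctest (hand-built column; normally the parser fills these
    attributes):
    >>> col = Column()
    >>> col.name = "id"
    >>> col.column_type = "bigint"
    >>> col.length = 20
    >>> col.nullable = False
    >>> col.to_sql()
    '`id` bigint(20) NOT NULL'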
"""
def __init__(self):
self.name = None
self.column_type = None
self.default = None
self.charset = None
self.collate = None
self.length = None
self.comment = None
self.nullable = True
self.unsigned = None
self.is_default_bit = False
self.auto_increment = None
def __str__(self):
col_str = []
col_str.append("NAME: {}".format(self.name))
col_str.append("TYPE: {}".format(self.column_type))
if self.is_default_bit:
col_str.append("DEFAULT: b'{}'".format(self.default))
else:
col_str.append("DEFAULT: {}".format(self.default))
col_str.append("LENGTH: {}".format(self.length))
col_str.append("CHARSET: {}".format(self.charset))
col_str.append("COLLATE: {}".format(self.collate))
col_str.append("NULLABLE: {}".format(self.nullable))
col_str.append("UNSIGNED: {}".format(self.unsigned))
col_str.append("COMMENT: {}".format(self.comment))
return " ".join(col_str)
@property
def quoted_default(self):
"""
        Quote the default value if it's a numeric string. This matches what
        MySQL does when the statement is executed without quotes.
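        Illustrative doctest:
        >>> col = Column()
        >>> col.default = "0"
        >>> col.quoted_default
        "'0'"
        >>> col.default = "CURRENT_TIMESTAMP"
        >>> col.quoted_default
        'CURRENT_TIMESTAMP'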
"""
try:
float(self.default)
return "'{}'".format(self.default)
except (ValueError, TypeError):
return self.default
def __eq__(self, other):
for attr in (
"name",
"column_type",
"charset",
"collate",
"length",
"comment",
"nullable",
"unsigned",
"is_default_bit",
"auto_increment",
):
# Ignore display width of *int types, because of the new default in 8.0.20.
# This is a bit of a heavy hammer, but it's the simpler alternative to be
# able to support mixed version comparisons
# Ref: https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-19.html
# (search for: "Display width specification for integer data types")
int_types = {"int", "bigint", "tinyint", "smallint", "mediumint"}
if self.column_type.lower() in int_types and attr == "length":
continue
if not is_equal(getattr(self, attr), getattr(other, attr)):
return False
return self.has_same_default(other)
def has_same_default(self, other):
        # A nullable column has an implicit default of NULL
if self.nullable:
if self.quoted_default != other.quoted_default:
                # An implicit NULL default is equal to an explicit DEFAULT NULL.
                # Other than that, any difference between the two default
                # values makes them semantically different.
left_default_is_null = (
self.default is None or self.default.upper() == "NULL"
)
right_default_is_null = (
other.default is None or other.default.upper() == "NULL"
)
if not (left_default_is_null and right_default_is_null):
return False
else:
if self.quoted_default != other.quoted_default:
return False
return True
def __ne__(self, other):
return not self == other
def to_sql(self):
column_segment = []
column_segment.append("`{}`".format(escape(self.name)))
if self.length is not None:
column_segment.append("{}({})".format(self.column_type, self.length))
else:
column_segment.append("{}".format(self.column_type))
if self.charset is not None:
column_segment.append("CHARACTER SET {}".format(self.charset))
if self.unsigned is not None:
column_segment.append("UNSIGNED")
if self.collate is not None:
column_segment.append("COLLATE {}".format(self.collate))
        # By default MySQL will implicitly make a column nullable if not
        # specified otherwise
if self.nullable or self.nullable is None:
column_segment.append("NULL")
else:
column_segment.append("NOT NULL")
if self.default is not None:
if self.is_default_bit:
column_segment.append("DEFAULT b{}".format(self.default))
else:
column_segment.append("DEFAULT {}".format(self.default))
if self.auto_increment is not None:
column_segment.append("AUTO_INCREMENT")
if self.comment is not None:
column_segment.append("COMMENT {}".format(self.comment))
return " ".join(column_segment)
class TimestampColumn(Column):
"""
    A timestamp type column. It's different from other types of columns
    because it allows CURRENT_TIMESTAMP as a default value, and it has a
    special "ON UPDATE" attribute.
"""
def __init__(self):
super(TimestampColumn, self).__init__()
self.on_update_current_timestamp = None
# We will not use the default nullable=True here, because timestamp
# default behaviour is special
self.nullable = None
def __str__(self):
col_str = super(TimestampColumn, self).__str__()
col_str += " ON UPDATE: {}".format(self.on_update_current_timestamp)
return col_str
    def explicit_ts_default(self):
        """
        This is a special case for TIMESTAMP.
        If you define a column as
            `col` timestamp
        it has exactly the same meaning as
            `col` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
            ON UPDATE CURRENT_TIMESTAMP
        See also:
        http://dev.mysql.com/doc/refman/5.6/en/timestamp-initialization.html
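        Illustrative doctest (a bare TIMESTAMP column built by hand):
        >>> ts = TimestampColumn()
        >>> ts.column_type = "TIMESTAMP"
        >>> ts.explicit_ts_default()
        >>> (ts.nullable, ts.default, ts.on_update_current_timestamp)
        (False, 'CURRENT_TIMESTAMP', 'CURRENT_TIMESTAMP')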
"""
if self.column_type == "TIMESTAMP":
if all(
[
(self.nullable is None or not self.nullable),
self.default is None,
self.on_update_current_timestamp is None,
]
):
self.nullable = False
self.default = "CURRENT_TIMESTAMP"
self.on_update_current_timestamp = "CURRENT_TIMESTAMP"
else:
            # Except for timestamp, all other types have the implicit nullable
            # behavior.
if self.nullable is None:
self.nullable = True
def __eq__(self, other):
self.explicit_ts_default()
if getattr(other, "explicit_ts_default", None):
other.explicit_ts_default()
if not super(TimestampColumn, self).__eq__(other):
return False
for attr in ("on_update_current_timestamp",):
if not is_equal(getattr(self, attr), getattr(other, attr)):
return False
return True
def __ne__(self, other):
return not self == other
def to_sql(self):
self.explicit_ts_default()
column_segment = []
column_segment.append("`{}`".format(escape(self.name)))
if self.length is not None:
column_segment.append("{}({})".format(self.column_type, self.length))
else:
column_segment.append("{}".format(self.column_type))
if self.nullable:
column_segment.append("NULL")
else:
column_segment.append("NOT NULL")
if self.default is not None:
column_segment.append("DEFAULT {}".format(self.default))
if self.on_update_current_timestamp is not None:
column_segment.append(
"ON UPDATE {}".format(self.on_update_current_timestamp)
)
if self.comment is not None:
column_segment.append("COMMENT {}".format(self.comment))
return " ".join(column_segment)
class SetColumn(Column):
"""
    A set type column. It's different from other types of columns because it
    has a list of allowed values in its definition.
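    Illustrative doctest (the quoted values are assumptions about how the
    parser stores them):
    >>> sc = SetColumn()
    >>> sc.name = "flags"
    >>> sc.column_type = "set"
    >>> sc.set_list = ["'a'", "'b'"]
    >>> sc.to_sql()
    "`flags` set('a', 'b') NULL"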
"""
def __init__(self):
super(SetColumn, self).__init__()
self.set_list = []
def __str__(self):
col_str = super(SetColumn, self).__str__()
col_str += " SET VALUES: [{}]".format(", ".join(self.set_list))
return col_str
def __eq__(self, other):
if not super(SetColumn, self).__eq__(other):
return False
return self.set_list == other.set_list
def __ne__(self, other):
return not self == other
def to_sql(self):
column_segment = []
column_segment.append("`{}`".format(escape(self.name)))
column_segment.append(
"{}({})".format(self.column_type, ", ".join(self.set_list))
)
if self.nullable:
column_segment.append("NULL")
else:
column_segment.append("NOT NULL")
if self.default is not None:
column_segment.append("DEFAULT {}".format(self.default))
if self.comment is not None:
column_segment.append("COMMENT {}".format(self.comment))
return " ".join(column_segment)
class EnumColumn(Column):
"""
    An enum type column. It's different from other types of columns because it
    has a list of allowed values in its definition.
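    Illustrative doctest (with an explicit character set; the quoted values
    are assumptions about how the parser stores them):
    >>> ec = EnumColumn()
    >>> ec.name = "status"
    >>> ec.column_type = "enum"
    >>> ec.enum_list = ["'on'", "'off'"]
    >>> ec.charset = "utf8mb4"
    >>> ec.to_sql()
    "`status` enum('on', 'off') CHARACTER SET utf8mb4 NULL"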
"""
def __init__(self):
super(EnumColumn, self).__init__()
self.enum_list = []
def __str__(self):
col_str = super(EnumColumn, self).__str__()
        col_str += " ENUM VALUES: [{}]".format(", ".join(self.enum_list))
return col_str
def __eq__(self, other):
if not super(EnumColumn, self).__eq__(other):
return False
return self.enum_list == other.enum_list
def __ne__(self, other):
return not self == other
def to_sql(self):
column_segment = []
column_segment.append("`{}`".format(escape(self.name)))
column_segment.append(
"{}({})".format(self.column_type, ", ".join(self.enum_list))
)
if self.charset is not None:
column_segment.append("CHARACTER SET {}".format(self.charset))
if self.collate is not None:
column_segment.append("COLLATE {}".format(self.collate))
if self.nullable:
column_segment.append("NULL")
else:
column_segment.append("NOT NULL")
if self.default is not None:
column_segment.append("DEFAULT {}".format(self.default))
if self.comment is not None:
column_segment.append("COMMENT {}".format(self.comment))
return " ".join(column_segment)
class PartitionDefinitionEntry(NamedTuple):
pdef_name: str
pdef_type: str
pdef_value_list: Union[List[str], str]
pdef_comment: Optional[str]
pdef_engine: str = "INNODB"
is_tuple: bool = False
class PartitionConfig:
# Partitions config for a table
PTYPE_RANGE = "RANGE"
PTYPE_LIST = "LIST"
PTYPE_HASH = "HASH"
PTYPE_KEY = "KEY"
SUBTYPE_L = "LINEAR"
SUBTYPE_C = "COLUMNS"
KNOWN_PARTITION_TYPES: Set[str] = {PTYPE_LIST, PTYPE_HASH, PTYPE_KEY, PTYPE_RANGE}
KNOWN_PARTITION_SUBTYPES: Set[str] = {SUBTYPE_L, SUBTYPE_C}
PDEF_TYPE_VIN = "p_values_in"
PDEF_TYPE_VLT = "p_values_less_than"
PDEF_TYPE_ATTRIBS: List[str] = [PDEF_TYPE_VIN, PDEF_TYPE_VLT]
TYPE_MAP = {
PDEF_TYPE_VIN: "IN",
PDEF_TYPE_VLT: "LESS THAN",
}
def __init__(self) -> None:
self.part_type: Optional[str] = None # Partition type e.g. RANGE
self.p_subtype: Optional[str] = None # e.g. LINEAR / COLUMNS
self.num_partitions: int = 0
self.fields_or_expr: Optional[Union[str, List[str]]] = None
self.part_defs: List[PartitionDefinitionEntry] = []
self.full_type: str = ""
# Partition type `KEY` alone allows specifying ALGORITHM=[1|2] e.g.
# `PARTITION BY linear key ALGORITHM=2 (id) partitions 10`
self.algorithm_for_key: Optional[int] = None
self.via_nested_expr = False
def __str__(self):
return (
f"{self.__class__.__name__}: |"
f"type={self.full_type}|"
f"fields_or_expr={self.fields_or_expr}|"
f"defs={self.part_defs}|numparts={self.num_partitions}"
)
def get_type(self) -> Optional[str]:
return self.full_type
def get_num_parts(self) -> int:
return self.num_partitions
def get_fields_or_expr(self) -> Optional[Union[str, List[str]]]:
return self.fields_or_expr
def get_algo(self) -> Optional[int]:
return self.algorithm_for_key if self.part_type == self.PTYPE_KEY else None
def __eq__(self, other):
for attr in (
"part_type",
"p_subtype",
"num_partitions",
"fields_or_expr",
"full_type",
"algorithm_for_key",
):
if not is_equal(getattr(self, attr), getattr(other, attr)):
return False
return self.part_defs == other.part_defs
def __ne__(self, other):
return not self == other
def add_quote(self, field: str) -> str:
return f"`{field}`"
def to_partial_sql(self):
        # Stringify into a format usable in `create table ...`
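        # Illustrative usage (hand-populated config; normally the parser fills
        # these attributes from the original PARTITION BY clause):
        #   pc = PartitionConfig()
        #   pc.part_type = PartitionConfig.PTYPE_KEY
        #   pc.full_type = "KEY"
        #   pc.fields_or_expr = ["id"]
        #   pc.num_partitions = 10
        #   pc.to_partial_sql()  # -> "PARTITION BY KEY (`id`) PARTITIONS 10"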
def _proc_list(vals: Union[str, List[str]]) -> str:
# Helper to convert expr list to an expression value-list
if isinstance(vals, list) and all(isinstance(v, str) for v in vals):
return "(" + ", ".join(vals) + ")"
ret = ""
for v in vals:
if isinstance(v, list):
ret += _proc_list(v)
else:
ret += v
return ret
output = f"PARTITION BY {self.full_type}"
if self.part_type == self.PTYPE_KEY:
if self.algorithm_for_key is not None:
output += f" ALGORITHM={self.algorithm_for_key}"
fields = ", ".join(self.add_quote(f) for f in self.fields_or_expr)
output += f" ({fields})"
if self.num_partitions > 1:
output += f" PARTITIONS {self.num_partitions}"
return output
elif self.part_type == self.PTYPE_HASH:
output += f" ({_proc_list(self.fields_or_expr)})"
if self.num_partitions > 1:
output += f" PARTITIONS {self.num_partitions}"
return output
elif self.part_type == self.PTYPE_RANGE or self.part_type == self.PTYPE_LIST:
partitions: List[str] = []
for pd in self.part_defs:
name = f"`{pd.pdef_name}`" if pd.pdef_name.isdigit() else pd.pdef_name
ty = self.TYPE_MAP[pd.pdef_type]
expr_or_value_list = (
_proc_list(pd.pdef_value_list)
if isinstance(pd.pdef_value_list, list)
else pd.pdef_value_list
)
eng = pd.pdef_engine
if pd.is_tuple:
expr_or_value_list = f"({expr_or_value_list})"
thispart = (
f"PARTITION {name} VALUES {ty} {expr_or_value_list} ENGINE {eng}"
)
comment = pd.pdef_comment
if comment is not None:
thispart += f" COMMENT {comment}"
partitions.append(thispart)
f_or_e = _proc_list(self.fields_or_expr)
if self.via_nested_expr:
                # PART_EXPR in sqlparse uses nestedExpr to acquire this and
                # strips the parens, so "undo" that here
f_or_e = f"({f_or_e})"
output += f" {f_or_e} (\n" + ",\n".join(partitions) + ")"
return output
class Table(object):
"""
    Representing a table definition
"""
def __init__(self):
self.table_options = []
self.name = None
self.engine = None
self.charset = None
self.collate = None
self.row_format = None
self.key_block_size = None
self.compression = None
self.auto_increment = None
self.comment = None
self.column_list = []
self.primary_key = TableIndex(name="PRIMARY", is_unique=True)
self.indexes = []
self.partition = None # Partitions as a string
self.constraint = None
self.partition_config: Optional[PartitionConfig] = None
self.has_80_features = False
def __str__(self):
table_str = ""
table_str += "NAME: {}\n".format(self.name)
table_str += "ENGINE: {}\n".format(self.engine)
table_str += "CHARSET: {}\n".format(self.charset)
table_str += "COLLATE: {}\n".format(self.collate)
table_str += "ROW_FORMAT: {}\n".format(self.row_format)
table_str += "KEY_BLOCK_SIZE: {}\n".format(self.key_block_size)
table_str += "COMPRESSION: {}\n".format(self.compression)
table_str += "AUTO_INCREMENT: {}\n".format(self.auto_increment)
table_str += "COMMENT: {}\n".format(self.comment)
table_str += "PARTITION: {}\n".format(self.partition)
for col in self.column_list:
table_str += "[{}]\n".format(str(col))
table_str += "PRIMARY_KEYS: \n"
table_str += "\t{}\n".format(str(self.primary_key))
table_str += "INDEXES: \n"
for index in self.indexes:
table_str += "\t{}\n".format(str(index))
table_str += "Constraint: {}".format(str(self.constraint))
return table_str
def __eq__(self, other):
for attr in (
"name",
"engine",
"charset",
"collate",
"row_format",
"key_block_size",
"comment",
# "partition",
"partition_config",
):
if not is_equal(getattr(self, attr), getattr(other, attr)):
return False
if self.primary_key != other.primary_key:
return False
for idx in self.indexes:
if idx not in other.indexes:
return False
for idx in other.indexes:
if idx not in self.indexes:
return False
if self.column_list != other.column_list:
return False
# If we get to this point, the two table structures are identical
return True
def __ne__(self, other):
return not self == other
def to_sql(self):
"""
        A standardized CREATE TABLE statement for creating the table
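        Illustrative doctest (a minimal hand-built table; normally the parser
        populates these fields):
        >>> t = Table()
        >>> t.name = "t1"
        >>> c = Column()
        >>> c.name = "id"
        >>> c.column_type = "bigint"
        >>> t.column_list.append(c)
        >>> t.to_sql().startswith("CREATE TABLE `t1` (")
        True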
"""
sql = "CREATE TABLE `{}` (\n".format(escape(self.name))
col_strs = []
for column in self.column_list:
col_strs.append(" " + column.to_sql())
sql += ",\n".join(col_strs)
if self.primary_key.column_list:
sql += ",\n {}".format(self.primary_key.to_sql())
if self.indexes:
for idx in self.indexes:
sql += ",\n " + idx.to_sql()
sql += "\n) "
if self.engine is not None:
sql += "ENGINE={} ".format(self.engine)
if self.auto_increment is not None:
sql += "AUTO_INCREMENT={} ".format(self.auto_increment)
if self.charset is not None:
sql += "DEFAULT CHARSET={} ".format(self.charset)
if self.collate is not None:
sql += "COLLATE={} ".format(self.collate)
if self.row_format is not None:
sql += "ROW_FORMAT={} ".format(self.row_format)
if self.key_block_size is not None:
sql += "KEY_BLOCK_SIZE={} ".format(self.key_block_size)
if self.compression is not None:
sql += "COMPRESSION={} ".format(self.compression)
if self.comment is not None:
sql += "COMMENT={} ".format(self.comment)
if self.partition is not None:
sql += "\n{} ".format(self.partition)
return sql
@property
def checksum(self):
"""
        Generate an MD5 hash for the schema that this object represents.
        In theory, two identical table schemas should have exactly the same
        CREATE TABLE statement after standardization, and their MD5 hashes
        should be the same as well.
        So you can tell whether two schemas have the same structure by
        comparing their checksum values.
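        Illustrative doctest (two hand-built tables with the same definition):
        >>> t1, t2 = Table(), Table()
        >>> t1.name = t2.name = "foo"
        >>> t1.checksum == t2.checksum
        True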
"""
md5_obj = hashlib.md5(self.to_sql().encode("utf-8"))
return md5_obj.hexdigest()
def droppable_indexes(self, keep_unique_key=False):
"""
        Dropping indexes before loading data and recreating them afterwards
        can make the whole process faster, and the resulting indexes will be
        more compact than if they had been populated during the load.
        This function returns a list of droppable indexes for the purpose of
        fast index recreation.
@param keep_unique_key: Keep unique key or not
@type keep_unique_key: bool
@return: a list of droppable indexes to make load faster
@rtype : [TableIndex]
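        Illustrative doctest (a unique key is droppable unless we ask to keep
        unique keys):
        >>> t = Table()
        >>> uniq = TableIndex(name="uq_email", is_unique=True)
        >>> t.indexes.append(uniq)
        >>> [i.name for i in t.droppable_indexes()]
        ['uq_email']
        >>> [i.name for i in t.droppable_indexes(keep_unique_key=True)]
        []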
"""
# Primary key should not be dropped, but it's not included in
# table.indexes so we are fine here
idx_list = []
auto_incre_name = ""
for col in self.column_list:
if col.auto_increment:
auto_incre_name = col.name
break
for idx in self.indexes:
            # Dropping an index which contains only the auto_increment column
            # is not allowed
if len(idx.column_list) == 1:
if auto_incre_name and auto_incre_name == idx.column_list[0].name:
continue
            # We can drop a unique index most of the time. Only when we want
            # to ignore duplicate keys while adding new unique indexes do we
            # need the index to exist on the new table before loading data,
            # so that we can utilize "LOAD IGNORE" to skip the duplicated rows.
if keep_unique_key and idx.is_unique:
continue
idx_list.append(idx)
return idx_list
@property
def is_myrocks_ttl_table(self):
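        """
        Whether this is a MyRocks (RocksDB) table with a TTL configured in its
        comment. Illustrative doctest (hand-built table; the comment format
        mirrors the regexes below):
        >>> t = Table()
        >>> t.engine = "ROCKSDB"
        >>> t.comment = "ttl_duration=3600;"
        >>> t.is_myrocks_ttl_table
        True
        """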
if not self.engine:
return False
if self.engine.upper() == "ROCKSDB":
if self.comment:
# partition level ttl
if re.search(r"\S+_ttl_duration=[0-9]+;", self.comment):
return True
# table level ttl
elif re.search(r"ttl_duration=[0-9]+;", self.comment):
return True
else:
return False
else:
return False
else:
return False