# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cassandra.metadata import maybe_escape_name
from cqlshlib import helptopics
from cqlshlib.cqlhandling import CqlParsingRuleSet, Hint

# CQL type names offered by storagetype_completer for <simpleStorageType>.
# Parameterised types have their own grammar productions (<collectionType>,
# <frozenCollectionType>, <vectorType>) and are kept out of this set.
simple_cql_types = set(
    'ascii bigint blob boolean counter date decimal double duration float '
    'inet int smallint text time timestamp timeuuid tinyint uuid varchar '
    'varint'.split()
)
simple_cql_types -= {'set', 'map', 'list', 'vector'}


class UnexpectedTableStructure(UserWarning):
    """Warning noting that a table structure may not translate correctly to CQL.

    :param msg: detail text; ``__str__`` appends it to a fixed prefix.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        prefix = 'Unexpected table structure; may not translate correctly to CQL. '
        return prefix + self.msg


# Keyspaces excluded by non_system_ks_name_completer when offering keyspace names.
SYSTEM_KEYSPACES = ('system', 'system_schema', 'system_traces', 'system_auth', 'system_distributed', 'system_views',
                    'system_metrics', 'system_virtual_schema', 'system_cluster_metadata')
# Keyspaces excluded by alterable_ks_name_completer.  The misspelled name is
# kept for backward compatibility; NONALTERABLE_KEYSPACES is a correctly
# spelled alias for the very same tuple.
NONALTERBALE_KEYSPACES = ('system', 'system_schema', 'system_views', 'system_metrics', 'system_virtual_schema',
                          'system_cluster_metadata')
NONALTERABLE_KEYSPACES = NONALTERBALE_KEYSPACES


class Cql3ParsingRuleSet(CqlParsingRuleSet):
    """CQL3 rule set backing cqlsh parsing and tab completion.

    Extends :class:`CqlParsingRuleSet` with catalogues of table options,
    consistency levels and per-strategy compaction options that the
    completers in this module consult, plus helpers for quoting and
    unquoting CQL names and values.
    """

    # Scalar table options offered by the table-property completer, as
    # (CQL3 option name, None) pairs.
    columnfamily_layout_options = (
        ('allow_auto_snapshot', None),
        ('bloom_filter_fp_chance', None),
        ('comment', None),
        ('gc_grace_seconds', None),
        ('incremental_backups', None),
        ('min_index_interval', None),
        ('max_index_interval', None),
        ('default_time_to_live', None),
        ('speculative_retry', None),
        ('additional_write_policy', None),
        ('memtable', None),
        ('memtable_flush_period_in_ms', None),
        ('cdc', None),
        ('read_repair', None),
    )

    columnfamily_layout_map_options = (
        # (CQL3 option name, schema_columnfamilies column name (or None if same),
        #  list of known map keys)
        ('compaction', 'compaction_strategy_options',
            ('class', 'max_threshold', 'tombstone_compaction_interval', 'tombstone_threshold', 'enabled',
             'unchecked_tombstone_compaction', 'only_purge_repaired_tombstones', 'provide_overlapping_tombstones')),
        ('compression', 'compression_parameters',
            ('class', 'chunk_length_in_kb', 'enabled', 'min_compress_ratio', 'max_compressed_length')),
        ('caching', None,
            ('rows_per_partition', 'keys')),
    )

    # Table options whose values complete as '<obsolete_option>'; currently none.
    obsolete_cf_options = ()

    consistency_levels = (
        'ANY',
        'ONE',
        'TWO',
        'THREE',
        'QUORUM',
        'ALL',
        'LOCAL_QUORUM',
        'EACH_QUORUM',
        'SERIAL'
    )

    # Extra compaction-map keys offered once 'class' names the matching strategy.
    size_tiered_compaction_strategy_options = (
        'min_sstable_size',
        'min_threshold',
        'bucket_high',
        'bucket_low'
    )

    leveled_compaction_strategy_options = (
        'sstable_size_in_mb',
        'fanout_size'
    )

    time_window_compaction_strategy_options = (
        'compaction_window_unit',
        'compaction_window_size',
        'min_threshold',
        'timestamp_resolution'
    )

    unified_compaction_strategy_options = (
        'scaling_parameters',
        'min_sstable_size',
        'flush_size_override',
        'base_shard_count',
        'target_sstable_size',
        'sstable_growth',
        'max_sstables_to_compact',
        'expired_sstable_check_frequency_seconds',
        'unsafe_aggressive_sstable_expiration',
        'overlap_inclusion_method'
    )

    @classmethod
    def escape_value(cls, value):
        """Render a Python value as CQL literal text.

        None becomes ``NULL``. Ints are rendered bare. Floats use ``repr()``
        so no precision is lost. Bools and strings become single-quoted
        literals with embedded single quotes doubled.
        """
        if value is None:
            return 'NULL'  # this totally won't work
        if isinstance(value, bool):
            # Checked before int because bool is a subclass of int.
            value = str(value).lower()
        elif isinstance(value, float):
            # '%f' would truncate to six decimals (1e-07 -> '0.000000');
            # repr() yields the shortest string that round-trips.
            return repr(value)
        elif isinstance(value, int):
            return str(value)
        return "'%s'" % value.replace("'", "''")

    @classmethod
    def escape_name(cls, name):
        """Quote *name* as a single-quoted CQL string literal (None -> ``NULL``).

        Note that this yields a string literal, not a double-quoted identifier.
        """
        if name is None:
            return 'NULL'
        return "'%s'" % name.replace("'", "''")

    @staticmethod
    def dequote_name(name):
        """Normalise an identifier as typed.

        Surrounding whitespace is stripped. A double-quoted name loses its
        quotes and has doubled quotes collapsed. Any other name is lowercased.
        """
        name = name.strip()
        if name == '':
            return name
        # Require two characters so a lone '"' is not taken as a quoted empty name.
        if len(name) >= 2 and name[0] == '"' and name[-1] == '"':
            return name[1:-1].replace('""', '"')
        return name.lower()

    @staticmethod
    def dequote_value(cqlword):
        """Strip whitespace and, if single-quoted, unquote a CQL string literal."""
        cqlword = cqlword.strip()
        if cqlword == '':
            return cqlword
        # Require two characters so a lone "'" is not taken as a quoted empty string.
        if len(cqlword) >= 2 and cqlword[0] == "'" and cqlword[-1] == "'":
            cqlword = cqlword[1:-1].replace("''", "'")
        return cqlword


# The single rule-set instance; completers in this module register on it
# through completer_for.
CqlRuleSet = Cql3ParsingRuleSet()

# Short module-level aliases for the rule set's helpers, used below.
completer_for, explain_completion = CqlRuleSet.completer_for, CqlRuleSet.explain_completion
dequote_value, dequote_name = CqlRuleSet.dequote_value, CqlRuleSet.dequote_name
escape_value = CqlRuleSet.escape_value

# BEGIN SYNTAX/COMPLETION RULE DEFINITIONS

# Base grammar in cqlshlib's rule notation: "<name> ::= ... ;" productions,
# quoted keywords, /regex/ terminals, and name= / [name]= bindings.  The
# completer_for(production, binding) registrations in this module are keyed on
# those bindings (e.g. 'property' / 'propname').  The string is runtime data;
# '#' lines inside it are grammar comments, not Python comments.
syntax_rules = r'''
<Start> ::= <CQL_Statement>*
          ;

<CQL_Statement> ::= [statements]=<statementBody> ";"
                  ;

# The order of these terminal productions is significant. The input string is matched to the rule
# specified first in the grammar.

<endline> ::= /\n/ ;

JUNK ::= /([ \t\r\f\v]+|(--|[/][/])[^\n\r]*([\n\r]|$)|[/][*].*?[*][/])/ ;

<stringLiteral> ::= <quotedStringLiteral>
                  | <pgStringLiteral> ;
<quotedStringLiteral> ::= /'([^']|'')*'/ ;
<pgStringLiteral> ::= /\$\$(?:(?!\$\$).)*\$\$/;
<quotedName> ::=    /"([^"]|"")*"/ ;

<unclosedPgString>::= /\$\$(?:(?!\$\$).)*/ ;
<unclosedString>  ::= /'([^']|'')*/ ;
<unclosedName>    ::= /"([^"]|"")*/ ;
<unclosedComment> ::= /[/][*].*$/ ;

<float> ::=         /-?[0-9]+\.[0-9]+/ ;
<uuid> ::=          /[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/ ;
<blobLiteral> ::=    /0x[0-9a-f]+/ ;
<wholenumber> ::=   /[0-9]+/ ;
<identifier> ::=    /[a-z][a-z0-9_]*/ ;
<colon> ::=         ":" ;
<star> ::=          "*" ;
<endtoken> ::=      ";" ;
<op> ::=            /[-+=%/,().]/ ;
<cmp> ::=           /[<>!]=?/ ;
<brackets> ::=      /[][{}]/ ;

<integer> ::= "-"? <wholenumber> ;
<boolean> ::= "true"
            | "false"
            ;

<term> ::= <stringLiteral>
         | <integer>
         | <float>
         | <uuid>
         | <boolean>
         | <blobLiteral>
         | <collectionLiteral>
         | <functionLiteral> <functionArguments>
         | "NULL"
         ;

<functionLiteral> ::= (<identifier> ( "." <identifier> )?)
                 | "TOKEN"
                 ;

<functionArguments> ::= "(" ( <term> ( "," <term> )* )? ")"
                 ;

<tokenDefinition> ::= token="TOKEN" "(" <term> ( "," <term> )* ")"
                    | <term>
                    ;
<cident> ::= <quotedName>
           | <identifier>
           | <unreservedKeyword>
           ;
<colname> ::= <cident> ;   # just an alias

<collectionLiteral> ::= <listLiteral>
                      | <setLiteral>
                      | <mapLiteral>
                      ;
<listLiteral> ::= "[" ( <term> ( "," <term> )* )? "]"
                ;
<setLiteral> ::= "{" ( <term> ( "," <term> )* )? "}"
               ;
<mapLiteral> ::= "{" <term> ":" <term> ( "," <term> ":" <term> )* "}"
               ;

<anyFunctionName> ::= ( ksname=<cfOrKsName> dot="." )? udfname=<cfOrKsName> ;

<userFunctionName> ::= ( ksname=<nonSystemKeyspaceName> dot="." )? udfname=<cfOrKsName> ;

<refUserFunctionName> ::= udfname=<cfOrKsName> ;

<userAggregateName> ::= ( ksname=<nonSystemKeyspaceName> dot="." )? udaname=<cfOrKsName> ;

<functionAggregateName> ::= ( ksname=<nonSystemKeyspaceName> dot="." )? functionname=<cfOrKsName> ;

<aggregateName> ::= <userAggregateName>
                  ;

<functionName> ::= <functionAggregateName>
                 | "TOKEN"
                 ;

<statementBody> ::= <useStatement>
                  | <selectStatement>
                  | <dataChangeStatement>
                  | <schemaChangeStatement>
                  | <authenticationStatement>
                  | <authorizationStatement>
                  ;

<dataChangeStatement> ::= <insertStatement>
                        | <updateStatement>
                        | <deleteStatement>
                        | <truncateStatement>
                        | <batchStatement>
                        ;

<schemaChangeStatement> ::= <createKeyspaceStatement>
                          | <createColumnFamilyStatement>
                          | <copyTableStatement>
                          | <createIndexStatement>
                          | <createMaterializedViewStatement>
                          | <createUserTypeStatement>
                          | <createFunctionStatement>
                          | <createAggregateStatement>
                          | <createTriggerStatement>
                          | <dropKeyspaceStatement>
                          | <dropColumnFamilyStatement>
                          | <dropIndexStatement>
                          | <dropMaterializedViewStatement>
                          | <dropUserTypeStatement>
                          | <dropFunctionStatement>
                          | <dropAggregateStatement>
                          | <dropTriggerStatement>
                          | <alterTableStatement>
                          | <alterKeyspaceStatement>
                          | <alterUserTypeStatement>
                          ;

<authenticationStatement> ::= <createUserStatement>
                            | <alterUserStatement>
                            | <dropUserStatement>
                            | <listUsersStatement>
                            | <createRoleStatement>
                            | <alterRoleStatement>
                            | <dropRoleStatement>
                            | <listRolesStatement>
                            | <listSuperUsersStatement>
                            ;

<authorizationStatement> ::= <grantStatement>
                           | <grantRoleStatement>
                           | <revokeStatement>
                           | <revokeRoleStatement>
                           | <listPermissionsStatement>
                           ;

# timestamp is included here, since it's also a keyword
<simpleStorageType> ::= typename=( <identifier> | <stringLiteral> | "timestamp" ) ;

<userType> ::= utname=<cfOrKsName> ;

<storageType> ::= ( <simpleStorageType> | <collectionType> | <frozenCollectionType> | <vectorType> | <userType> ) ( <constraintsExpr> )? ( <column_mask> )? ;

<constraintsExpr> ::= "CHECK" <constraint> ( "AND" <constraint> )*
                    ;

<constraint> ::= "NOT" "NULL"
               | <constraintStandaloneFunction>
               | <constraintComparableFunction> <functionArguments> <cmp> <term>
               | <cident> <cmp> <term>
               ;

<constraintComparableFunction> ::= "LENGTH"
                                 | "OCTET_LENGTH"
                                 | "REGEXP"
                                 ;

<constraintStandaloneFunction> ::= "JSON"
                                 ;

<column_mask> ::= "MASKED" "WITH" ( "DEFAULT" | <functionName> <selectionFunctionArguments> );

# Note: autocomplete for frozen collection types does not handle nesting past depth 1 properly,
# but that's a lot of work to fix for little benefit.
<collectionType> ::= "map" "<" <simpleStorageType> "," ( <simpleStorageType> | <userType> ) ">"
                   | "list" "<" ( <simpleStorageType> | <userType> ) ">"
                   | "set" "<" ( <simpleStorageType> | <userType> ) ">"
                   ;

<frozenCollectionType> ::= "frozen" "<" "map"  "<" <storageType> "," <storageType> ">" ">"
                         | "frozen" "<" "list" "<" <storageType> ">" ">"
                         | "frozen" "<" "set"  "<" <storageType> ">" ">"
                         ;

<vectorType> ::= "vector" "<" <storageType> "," <wholenumber> ">" ;

<columnFamilyName> ::= ( ksname=<cfOrKsName> dot="." )? cfname=<cfOrKsName> ;

<materializedViewName> ::= ( ksname=<cfOrKsName> dot="." )? mvname=<cfOrKsName> ;

<userTypeName> ::= ( ksname=<cfOrKsName> dot="." )? utname=<cfOrKsName> ;

<keyspaceName> ::= ksname=<cfOrKsName> ;

<nonSystemKeyspaceName> ::= ksname=<cfOrKsName> ;

<alterableKeyspaceName> ::= ksname=<cfOrKsName> ;

<cfOrKsName> ::= <identifier>
               | <quotedName>
               | <unreservedKeyword>;

<unreservedKeyword> ::= nocomplete=
                        ( "key"
                        | "clustering"
                        # | "count" -- to get count(*) completion, treat count as reserved
                        | "ttl"
                        | "compact"
                        | "storage"
                        | "type"
                        | "values" )
                      ;

<property> ::= [propname]=<cident> propeq="=" [propval]=<propertyValue>
                ;
<propertyValue> ::= propsimpleval=( <stringLiteral>
                                  | <identifier>
                                  | <integer>
                                  | <float>
                                  | <unreservedKeyword> )
                    # we don't use <mapLiteral> here so we can get more targeted
                    # completions:
                    | propsimpleval="{" [propmapkey]=<term> ":" [propmapval]=<term>
                            ( ender="," [propmapkey]=<term> ":" [propmapval]=<term> )*
                      ender="}"
                    ;
<propertyOrOption> ::= <property>
                     | "INDEXES"
                     ;

'''


def prop_equals_completer(ctxt, cass):
    """Offer '=' after a property name in a WITH clause.

    Keyspace properties always take '='.  For other statements, a
    property-name slot holding COMPACT or CLUSTERING begins an option phrase
    with no '=' (no property has either name), so nothing is offered there.
    """
    if working_on_keyspace(ctxt):
        return ['=']
    current = ctxt.get_binding('propname')[-1].upper()
    return () if current in ('COMPACT', 'CLUSTERING') else ['=']


completer_for('property', 'propeq')(prop_equals_completer)


@completer_for('property', 'propname')
def prop_name_completer(ctxt, cass):
    """Complete a WITH-clause property name for a keyspace, view or table."""
    if working_on_keyspace(ctxt):
        return ks_prop_name_completer(ctxt, cass)
    is_view = ctxt.get_binding('wat', '').upper() == 'MATERIALIZED'
    names = cf_prop_name_completer(ctxt, cass)
    if is_view:
        # These table options are not offered for materialized views.
        for excluded in ('default_time_to_live', 'gc_grace_seconds'):
            names.remove(excluded)
    return names


@completer_for('propertyValue', 'propsimpleval')
def prop_val_completer(ctxt, cass):
    """Complete a simple (or map-opening) property value for a keyspace or table."""
    completer = ks_prop_val_completer if working_on_keyspace(ctxt) else cf_prop_val_completer
    return completer(ctxt, cass)


@completer_for('propertyValue', 'propmapkey')
def prop_val_mapkey_completer(ctxt, cass):
    """Complete a key inside a map-valued keyspace or table property."""
    if not working_on_keyspace(ctxt):
        return cf_prop_val_mapkey_completer(ctxt, cass)
    return ks_prop_val_mapkey_completer(ctxt, cass)


@completer_for('propertyValue', 'propmapval')
def prop_val_mapval_completer(ctxt, cass):
    """Complete a value inside a map-valued keyspace or table property."""
    on_keyspace = working_on_keyspace(ctxt)
    if on_keyspace:
        return ks_prop_val_mapval_completer(ctxt, cass)
    return cf_prop_val_mapval_completer(ctxt, cass)


@completer_for('propertyValue', 'ender')
def prop_val_mapender_completer(ctxt, cass):
    """Complete the ',' / '}' that follows a map entry in a property value."""
    delegate = (ks_prop_val_mapender_completer if working_on_keyspace(ctxt)
                else cf_prop_val_mapender_completer)
    return delegate(ctxt, cass)


def ks_prop_name_completer(ctxt, cass):
    """Suggest the next keyspace property name.

    ``replication`` is suggested until it has been given.  After that,
    ``durable_writes`` is suggested unless it has been given too, in which
    case there is nothing left to suggest.
    """
    optsseen = ctxt.get_binding('propname', ())
    if 'replication' not in optsseen:
        return ['replication']
    if 'durable_writes' not in optsseen:
        return ['durable_writes']
    return []


def ks_prop_val_completer(ctxt, cass):
    """Suggest a starting value for the keyspace property currently being set."""
    starters = {
        'durable_writes': ["'true'", "'false'"],
        'replication': ["{'class': '"],
    }
    return starters.get(ctxt.get_binding('propname')[-1], ())


def ks_prop_val_mapkey_completer(ctxt, cass):
    """Suggest the next key inside a keyspace ``replication = {...}`` map.

    Until a ``'class'`` entry has been given, only ``'class'`` is offered.
    For SimpleStrategy, ``'replication_factor'`` is offered if not yet given.
    For NetworkTopologyStrategy, a ``<dc_name>`` hint is returned.  Any other
    strategy yields no suggestions instead of failing.
    """
    optname = ctxt.get_binding('propname')[-1]
    if optname != 'replication':
        return ()
    keysseen = list(map(dequote_value, ctxt.get_binding('propmapkey', ())))
    valsseen = list(map(dequote_value, ctxt.get_binding('propmapval', ())))
    repclass = next((v for k, v in zip(keysseen, valsseen) if k == 'class'), None)
    if repclass is None:
        return ["'class'"]
    # Accept package-qualified class names as well as short ones, mirroring
    # the compaction-class handling in cf_prop_val_mapkey_completer.
    repclass = repclass.split('.')[-1]
    if repclass == 'SimpleStrategy':
        return list(map(escape_value, {'replication_factor'}.difference(keysseen)))
    if repclass == 'NetworkTopologyStrategy':
        return [Hint('<dc_name>')]
    # Unknown/custom strategy: previously fell through with `opts` unbound.
    return ()


def ks_prop_val_mapval_completer(ctxt, cass):
    """Suggest a value for the current key of the keyspace ``replication`` map.

    The ``'class'`` key gets the known replication strategies; any other key
    gets a generic ``<term>`` hint.
    """
    if ctxt.get_binding('propname')[-1] != 'replication':
        return ()
    key = dequote_value(ctxt.get_binding('propmapkey')[-1])
    if key != 'class':
        return [Hint('<term>')]
    return [escape_value(strategy) for strategy in CqlRuleSet.replication_strategies]


def ks_prop_val_mapender_completer(ctxt, cass):
    """Suggest ',' or '}' after a value in a keyspace property map.

    Outside ``replication``, ',' is always offered.  Inside it, ',' is offered
    while required entries are still missing: the ``'class'`` key,
    ``'replication_factor'`` for SimpleStrategy, or any entry beyond
    ``'class'`` for NetworkTopologyStrategy.  Otherwise '}' closes the map.
    """
    optname = ctxt.get_binding('propname')[-1]
    if optname != 'replication':
        return [',']
    keysseen = list(map(dequote_value, ctxt.get_binding('propmapkey', ())))
    valsseen = list(map(dequote_value, ctxt.get_binding('propmapval', ())))
    repclass = next((v for k, v in zip(keysseen, valsseen) if k == 'class'), None)
    if repclass is None:
        return [',']
    # Accept package-qualified strategy names as well as short ones.
    repclass = repclass.split('.')[-1]
    if repclass == 'SimpleStrategy' and 'replication_factor' not in keysseen:
        return [',']
    if repclass == 'NetworkTopologyStrategy' and len(keysseen) == 1:
        return [',']
    return ['}']


def cf_prop_name_completer(ctxt, cass):
    """List every known table option name: scalar options first, then map options."""
    scalar = [opt[0] for opt in CqlRuleSet.columnfamily_layout_options]
    mapped = [opt[0] for opt in CqlRuleSet.columnfamily_layout_map_options]
    return scalar + mapped


def cf_prop_val_completer(ctxt, cass):
    """Suggest a value, or a value-type hint, for the table option being set.

    Map-valued options get the opening of a map literal.  Known scalar options
    get a type hint.  Anything else gets a generic ``<option_value>`` hint.
    """
    this_opt = ctxt.get_binding('propname')[-1]
    if this_opt in ('compression', 'compaction'):
        return ["{'class': '"]
    if this_opt == 'caching':
        return ["{'keys': '"]
    if any(this_opt == opt[0] for opt in CqlRuleSet.obsolete_cf_options):
        return ["'<obsolete_option>'"]
    if this_opt == 'bloom_filter_fp_chance':
        return [Hint('<float_between_0_and_1>')]
    if this_opt in ('min_compaction_threshold', 'max_compaction_threshold',
                    'gc_grace_seconds', 'min_index_interval', 'max_index_interval'):
        return [Hint('<integer>')]
    # Exact comparisons: `x in ('cdc')` would be a substring test, because
    # ('cdc') is a plain string rather than a one-element tuple.
    if this_opt == 'cdc':
        return [Hint('<true|false>')]
    if this_opt == 'read_repair':
        return [Hint('<\'none\'|\'blocking\'>')]
    if this_opt in ('allow_auto_snapshot', 'incremental_backups'):
        return [Hint('<boolean>')]
    return [Hint('<option_value>')]


def cf_prop_val_mapkey_completer(ctxt, cass):
    """Suggest keys for a map-valued table option (compaction/compression/caching).

    compression and caching offer their known keys that are not used yet.
    compaction offers only ``'class'`` until a class is given.  After that it
    offers the common compaction keys plus the strategy-specific keys for
    that class, matched on the last dotted component of the class name.
    """
    optname = ctxt.get_binding('propname')[-1]
    subopts = next((known for name, _, known in CqlRuleSet.columnfamily_layout_map_options
                    if name == optname), None)
    if subopts is None:
        return ()
    keysseen = [dequote_value(k) for k in ctxt.get_binding('propmapkey', ())]
    valsseen = [dequote_value(v) for v in ctxt.get_binding('propmapval', ())]
    pairsseen = dict(zip(keysseen, valsseen))
    if optname in ('compression', 'caching'):
        return [escape_value(k) for k in set(subopts).difference(keysseen)]
    if optname != 'compaction':
        return ()
    if 'class' not in pairsseen:
        return ["'class'"]
    strategy = pairsseen['class'].split('.')[-1]
    per_strategy = {
        'SizeTieredCompactionStrategy': CqlRuleSet.size_tiered_compaction_strategy_options,
        'LeveledCompactionStrategy': CqlRuleSet.leveled_compaction_strategy_options,
        'TimeWindowCompactionStrategy': CqlRuleSet.time_window_compaction_strategy_options,
        'UnifiedCompactionStrategy': CqlRuleSet.unified_compaction_strategy_options,
    }
    opts = set(subopts)
    extra = per_strategy.get(strategy)
    if extra is not None:
        opts = opts.union(set(extra))
    return [escape_value(o) for o in opts]


def cf_prop_val_mapval_completer(ctxt, cass):
    """Suggest a value for the current key of a map-valued table option."""
    opt = ctxt.get_binding('propname')[-1]
    key = dequote_value(ctxt.get_binding('propmapkey')[-1])
    if opt in ('compaction', 'compression'):
        if key == 'class':
            classes = (CqlRuleSet.available_compaction_classes if opt == 'compaction'
                       else CqlRuleSet.available_compression_classes)
            return [escape_value(c) for c in classes]
        if opt == 'compaction' and key == 'provide_overlapping_tombstones':
            return [Hint('<NONE|ROW|CELL>')]
        return [Hint('<option_value>')]
    if opt == 'caching':
        if key == 'rows_per_partition':
            return ["'ALL'", "'NONE'", Hint('#rows_per_partition')]
        if key == 'keys':
            return ["'ALL'", "'NONE'"]
    return ()


def cf_prop_val_mapender_completer(ctxt, cass):
    """After any table-option map entry, offer to continue (',') or close ('}')."""
    return list(',}')


@completer_for('tokenDefinition', 'token')
def token_word_completer(ctxt, cass):
    """Offer the TOKEN keyword where a token definition may begin."""
    keyword = 'TOKEN'
    return [keyword]


@completer_for('simpleStorageType', 'typename')
def storagetype_completer(ctxt, cass):
    """Offer the simple CQL type names.

    Returns the module-level ``simple_cql_types`` set object itself, not a copy.
    """
    type_names = simple_cql_types
    return type_names


@completer_for('keyspaceName', 'ksname')
def ks_name_completer(ctxt, cass):
    """Offer every keyspace name, passed through maybe_escape_name."""
    return [maybe_escape_name(ks) for ks in cass.get_keyspace_names()]


@completer_for('nonSystemKeyspaceName', 'ksname')
def non_system_ks_name_completer(ctxt, cass):
    """Offer keyspace names not listed in SYSTEM_KEYSPACES, escaped via maybe_escape_name."""
    return [maybe_escape_name(ks) for ks in cass.get_keyspace_names()
            if ks not in SYSTEM_KEYSPACES]


@completer_for('alterableKeyspaceName', 'ksname')
def alterable_ks_name_completer(ctxt, cass):
    """Offer keyspace names not listed in NONALTERBALE_KEYSPACES, escaped via maybe_escape_name."""
    return [maybe_escape_name(ks) for ks in cass.get_keyspace_names()
            if ks not in NONALTERBALE_KEYSPACES]


def cf_ks_name_completer(ctxt, cass):
    """Offer every keyspace name with a trailing '.', to start a qualified name."""
    return ['%s.' % maybe_escape_name(ks) for ks in cass.get_keyspace_names()]


completer_for('columnFamilyName', 'ksname')(cf_ks_name_completer)
completer_for('materializedViewName', 'ksname')(cf_ks_name_completer)


def cf_ks_dot_completer(ctxt, cass):
    """Offer '.' after a keyspace prefix, but only when that keyspace exists."""
    typed = dequote_name(ctxt.get_binding('ksname'))
    return ['.'] if typed in cass.get_keyspace_names() else []


completer_for('columnFamilyName', 'dot')(cf_ks_dot_completer)
completer_for('materializedViewName', 'dot')(cf_ks_dot_completer)


@completer_for('columnFamilyName', 'cfname')
def cf_name_completer(ctxt, cass):
    """Offer table names from the typed keyspace.

    With no keyspace typed, None is passed to cass (presumably meaning the
    current keyspace; confirm), and lookup errors then yield no suggestions.
    Errors for an explicitly typed keyspace propagate.
    """
    raw_ks = ctxt.get_binding('ksname', None)
    ks = None if raw_ks is None else dequote_name(raw_ks)
    try:
        cfnames = cass.get_columnfamily_names(ks)
    except Exception:
        if ks is not None:
            raise
        return ()
    return [maybe_escape_name(cf) for cf in cfnames]


@completer_for('materializedViewName', 'mvname')
def mv_name_completer(ctxt, cass):
    """Offer materialized view names from the typed keyspace.

    Lookup errors are swallowed only when no keyspace was typed.
    """
    raw_ks = ctxt.get_binding('ksname', None)
    ks = None if raw_ks is None else dequote_name(raw_ks)
    try:
        mvnames = cass.get_materialized_view_names(ks)
    except Exception:
        if ks is not None:
            raise
        return ()
    return [maybe_escape_name(mv) for mv in mvnames]


completer_for('userTypeName', 'ksname')(cf_ks_name_completer)

completer_for('userTypeName', 'dot')(cf_ks_dot_completer)


def ut_name_completer(ctxt, cass):
    """Offer user-defined type names from the typed keyspace, if any.

    Lookup errors are swallowed only when no keyspace was typed.
    """
    keyspace = ctxt.get_binding('ksname', None)
    if keyspace is not None:
        keyspace = dequote_name(keyspace)
    try:
        type_names = cass.get_usertype_names(keyspace)
    except Exception:
        if keyspace is not None:
            raise
        return ()
    return [maybe_escape_name(name) for name in type_names]


completer_for('userTypeName', 'utname')(ut_name_completer)
completer_for('userType', 'utname')(ut_name_completer)


@completer_for('unreservedKeyword', 'nocomplete')
def unreserved_keyword_completer(ctxt, cass):
    """Never complete through <unreservedKeyword>.

    The production exists only so that some keywords can be used as column
    names, table names, property values and so on.  It should never itself
    be the source of suggestions.
    """
    nothing = ()
    return nothing


def get_table_meta(ctxt, cass):
    """Fetch table metadata for the ksname/cfname bindings (ksname may be absent)."""
    ks = ctxt.get_binding('ksname', None)
    return cass.get_table_meta(None if ks is None else dequote_name(ks),
                               dequote_name(ctxt.get_binding('cfname')))


def get_ut_layout(ctxt, cass):
    """Fetch the layout of the user type named by the ksname/utname bindings."""
    ksname = ctxt.get_binding('ksname', None)
    if ksname is not None:
        ksname = dequote_name(ksname)
    utname = dequote_name(ctxt.get_binding('utname'))
    return cass.get_usertype_layout(ksname, utname)


def working_on_keyspace(ctxt):
    """Return True when the ``wat`` binding names KEYSPACE or SCHEMA (any case)."""
    return ctxt.get_binding('wat', '').upper() in ('KEYSPACE', 'SCHEMA')


# Extend the grammar with USE and SELECT: WHERE relations, GROUP BY / ORDER BY
# clauses, selectors, and the built-in function families selectors may call.
# The bindings here (rel_lhs, rel_tokname, ordercol, groupcol, ...) are what
# the completers registered just below hook into.
syntax_rules += r'''
<useStatement> ::= "USE" <keyspaceName>
                 ;
<selectStatement> ::= "SELECT" ( "JSON" )? <selectClause>
                        "FROM" (cf=<columnFamilyName> | mv=<materializedViewName>)
                          ( "WHERE" <whereClause> )?
                          ( "GROUP" "BY" <groupByClause> ( "," <groupByClause> )* )?
                          ( "ORDER" "BY" <orderByClause> ( "," <orderByClause> )* )?
                          ( "PER" "PARTITION" "LIMIT" perPartitionLimit=<wholenumber> )?
                          ( "LIMIT" limit=<wholenumber> )?
                          ( "ALLOW" "FILTERING" )?
                    ;
<whereClause> ::= <relation> ( "AND" <relation> )*
                ;
<relation> ::= [rel_lhs]=<cident> ( "[" <term> "]" )? ( "=" | "<" | ">" | "<=" | ">=" | "!=" | ( "NOT" )? "CONTAINS" ( "KEY" )? ) (<term> | <operandFunctions>)
             | token="TOKEN" "(" [rel_tokname]=<cident>
                                 ( "," [rel_tokname]=<cident> )*
                             ")" ("=" | "<" | ">" | "<=" | ">=") <tokenDefinition>
             | [rel_lhs]=<cident> (( "NOT" )? "IN" ) "(" <term> ( "," <term> )* ")"
             | [rel_lhs]=<cident> "BETWEEN" <term> "AND" <term>
             | <operandFunctions>
             ;
<selectClause> ::= "DISTINCT"? <selector> ("AS" <cident>)? ("," <selector> ("AS" <cident>)?)*
                 | "*"
                 ;
<udtSubfieldSelection> ::= <identifier> "." <identifier>
                         ;
<selector> ::= [colname]=<cident> ( "[" ( <term> ( ".." <term> "]" )? | <term> ".." ) )?
             | <udtSubfieldSelection>
             | "CAST" "(" <selector> "AS" <storageType> ")"
             | "TTL" "(" [colname]=<cident> ")"
             | "TOKEN" "(" [colname]=<cident> ")"
             | <aggregateMathFunctions>
             | <scalarMathFunctions>
             | <collectionFunctions>
             | <currentTimeFunctions>
             | <maskFunctions>
             | <timeConversionFunctions>
             | <writetimeFunctions>
             | <functionName> <selectionFunctionArguments>
             | <term>
             ;

<selectionFunctionArguments> ::= "(" ( <selector> ( "," <selector> )* )? ")"
                          ;
<orderByClause> ::= [ordercol]=<cident> ( "ASC" | "DESC" )?
                  ;
<groupByClause> ::= [groupcol]=<cident>
                  | <functionName><groupByFunctionArguments>
                  ;
<groupByFunctionArguments> ::= "(" ( <groupByFunctionArgument> ( "," <groupByFunctionArgument> )* )? ")"
                             ;
<groupByFunctionArgument> ::= [groupcol]=<cident>
                            | <term>
                            ;

<aggregateMathFunctions> ::= "COUNT" "(" star=( "*" | "1" ) ")"
             | "AVG" "(" [colname]=<cident> ")"
             | "MIN" "(" [colname]=<cident> ")"
             | "MAX" "(" [colname]=<cident> ")"
             | "SUM" "(" [colname]=<cident> ")"
             ;

<scalarMathFunctions> ::= "ABS" "(" [colname]=<cident> ")"
             | "EXP" "(" [colname]=<cident> ")"
             | "LOG" "(" [colname]=<cident> ")"
             | "LOG10" "(" [colname]=<cident> ")"
             | "ROUND" "(" [colname]=<cident> ")"
             ;

<collectionFunctions> ::= "MAP_KEYS" "(" [colname]=<cident> ")"
             | "MAP_VALUES" "(" [colname]=<cident> ")"
             | "COLLECTION_AVG" "(" [colname]=<cident> ")"
             | "COLLECTION_COUNT" "(" [colname]=<cident> ")"
             | "COLLECTION_MIN" "(" [colname]=<cident> ")"
             | "COLLECTION_MAX" "(" [colname]=<cident> ")"
             | "COLLECTION_SUM" "(" [colname]=<cident> ")"
             ;

<currentTimeFunctions> ::= "CURRENT_DATE()"
             | "CURRENT_TIME()"
             | "CURRENT_TIMESTAMP()"
             | "CURRENT_TIMEUUID()"
             ;

<maskFunctions> ::= "MASK_DEFAULT" "(" [colname]=<cident> ")"
             | "MASK_HASH" "(" [colname]=<cident> ")"
             | "MASK_INNER" "(" [colname]=<cident> "," <wholenumber> "," <wholenumber> ")"
             | "MASK_NULL" "(" [colname]=<cident> ")"
             | "MASK_REPLACE" "(" [colname]=<cident> "," <propertyValue> ")"
             | "MASK_OUTER" "(" [colname]=<cident> "," <wholenumber> "," <wholenumber> ")"
             ;

<timeConversionFunctions> ::= "TO_DATE" "(" [colname]=<cident> ")"
             | "TO_TIMESTAMP" "(" [colname]=<cident> ")"
             | "TO_UNIX_TIMESTAMP" "(" [colname]=<cident> ")"
             ;

<timeuuidFunctions> ::= "MAX_TIMEUUID" "(" [colname]=<cident> ")"
             | "MIN_TIMEUUID" "(" [colname]=<cident> ")"
             ;

<writetimeFunctions> ::= "MAX_WRITETIME" "(" [colname]=<cident> ")"
             | "MIN_WRITETIME" "(" [colname]=<cident> ")"
             | "WRITETIME" "(" [colname]=<cident> ")"
             ;
<operandFunctions> ::= <currentTimeFunctions> | <timeuuidFunctions>
             ;

'''


def udf_name_completer(ctxt, cass):
    """Offer user-defined function names from the typed keyspace, if any.

    Lookup errors are swallowed only when no keyspace was typed.
    """
    keyspace = ctxt.get_binding('ksname', None)
    if keyspace is not None:
        keyspace = dequote_name(keyspace)
    try:
        names = cass.get_userfunction_names(keyspace)
    except Exception:
        if keyspace is not None:
            raise
        return ()
    return [maybe_escape_name(name) for name in names]


def uda_name_completer(ctxt, cass):
    """Offer user-defined aggregate names from the typed keyspace, if any.

    Lookup errors are swallowed only when no keyspace was typed.
    """
    keyspace = ctxt.get_binding('ksname', None)
    if keyspace is not None:
        keyspace = dequote_name(keyspace)
    try:
        names = cass.get_useraggregate_names(keyspace)
    except Exception:
        if keyspace is not None:
            raise
        return ()
    return [maybe_escape_name(name) for name in names]


def udf_uda_name_completer(ctxt, cass):
    """Offer user function names followed by user aggregate names.

    Lookup errors are swallowed only when no keyspace was typed.
    """
    keyspace = ctxt.get_binding('ksname', None)
    if keyspace is not None:
        keyspace = dequote_name(keyspace)
    try:
        functions = cass.get_userfunction_names(keyspace)
        aggregates = cass.get_useraggregate_names(keyspace)
        combined = functions + aggregates
    except Exception:
        if keyspace is not None:
            raise
        return ()
    return [maybe_escape_name(name) for name in combined]


def ref_udf_name_completer(ctxt, cass):
    """Offer user function names looked up with no keyspace (best effort).

    Any lookup failure yields no suggestions.
    """
    try:
        udfnames = cass.get_userfunction_names(None)
    except Exception:
        return ()
    return [maybe_escape_name(name) for name in udfnames]


# Attach keyspace-prefix, '.', and name completers to each function/aggregate
# name production, in the same order as one-by-one registration.
for _production, _binding, _completer in (
        ('functionAggregateName', 'ksname', cf_ks_name_completer),
        ('functionAggregateName', 'dot', cf_ks_dot_completer),
        ('functionAggregateName', 'functionname', udf_uda_name_completer),
        ('anyFunctionName', 'ksname', cf_ks_name_completer),
        ('anyFunctionName', 'dot', cf_ks_dot_completer),
        ('anyFunctionName', 'udfname', udf_name_completer),
        ('userFunctionName', 'ksname', cf_ks_name_completer),
        ('userFunctionName', 'dot', cf_ks_dot_completer),
        ('userFunctionName', 'udfname', udf_name_completer),
        ('refUserFunctionName', 'udfname', ref_udf_name_completer),
        ('userAggregateName', 'ksname', cf_ks_name_completer),
        ('userAggregateName', 'dot', cf_ks_dot_completer),
        ('userAggregateName', 'udaname', uda_name_completer),
):
    completer_for(_production, _binding)(_completer)
del _production, _binding, _completer


@completer_for('orderByClause', 'ordercol')
def select_order_column_completer(ctxt, cass):
    """
    Suggest the next clustering column for ORDER BY.

    ORDER BY is only offered when the WHERE clause bound either 'keyname'
    or at least one 'rel_lhs'. Columns are offered one at a time in
    clustering-key order.
    """
    prev_order_cols = ctxt.get_binding('ordercol', ())
    if ctxt.get_binding('keyname') is None and not ctxt.get_binding('rel_lhs', ()):
        return [Hint("Can't ORDER BY here: need to specify partition key in WHERE clause")]
    clustering_names = [col.name for col in get_table_meta(ctxt, cass).clustering_key]
    next_pos = len(prev_order_cols)
    if next_pos < len(clustering_names):
        return [maybe_escape_name(clustering_names[next_pos])]
    return [Hint('No more orderable columns here.')]


@completer_for('groupByClause', 'groupcol')
def select_group_column_completer(ctxt, cass):
    """Suggest the next primary-key column (in key order) for GROUP BY."""
    already_grouped = ctxt.get_binding('groupcol', ())
    key_names = [col.name for col in get_table_meta(ctxt, cass).primary_key]
    next_pos = len(already_grouped)
    if next_pos < len(key_names):
        return [maybe_escape_name(key_names[next_pos])]
    return [Hint('No more columns here.')]


@completer_for('relation', 'token')
def relation_token_word_completer(ctxt, cass):
    """Offer the TOKEN keyword where a relation may start with TOKEN(...)."""
    suggestions = ['TOKEN']
    return suggestions


@completer_for('relation', 'rel_tokname')
def relation_token_subject_completer(ctxt, cass):
    """
    Suggest partition-key columns as arguments to TOKEN(...).

    Names go through maybe_escape_name, like every other column completer in
    this module. Columns that need quoting (mixed case, reserved words) are
    then completed in a form CQL accepts, instead of as bare identifiers.
    """
    layout = get_table_meta(ctxt, cass)
    return [maybe_escape_name(key.name) for key in layout.partition_key]


@completer_for('relation', 'rel_lhs')
def select_relation_lhs_completer(ctxt, cass):
    """
    Suggest columns usable on the left-hand side of a WHERE relation.

    A partition or clustering key column is offered if it is the first column
    of its key, or if the preceding column of that key is already restricted.
    Every index target is always offered.
    """
    layout = get_table_meta(ctxt, cass)
    restricted = [dequote_name(name) for name in ctxt.get_binding('rel_lhs', ())]

    def usable_prefix(key_columns):
        # Yield key column names up to (and including) the first one whose
        # predecessor has not been restricted yet.
        for pos, column in enumerate(key_columns):
            if pos > 0 and key_columns[pos - 1].name not in restricted:
                return
            yield column.name

    candidates = set(usable_prefix(layout.partition_key))
    candidates.update(usable_prefix(layout.clustering_key))
    candidates.update(index.index_options["target"] for index in layout.indexes.values())
    return [maybe_escape_name(name) for name in candidates]


# Register an explanation-style completion (no custom text given) for the
# 'colname' symbol of the selector rule.
explain_completion('selector', 'colname')

# INSERT grammar. The named bindings here (cf, colname, newval, valcomma,
# insertopt) are read by the insertStatement completers below. Note that
# valcomma binds both the separating "," and the closing ")" so that a single
# completer (insert_valcomma_completer) can pick which one to offer.
syntax_rules += r'''
<insertStatement> ::= "INSERT" "INTO" cf=<columnFamilyName>
                      ( ( "(" [colname]=<cident> ( "," [colname]=<cident> )* ")"
                          "VALUES" "(" [newval]=<term> ( valcomma="," [newval]=<term> )* valcomma=")")
                        | ("JSON" <stringLiteral>))
                      ( "IF" "NOT" "EXISTS")?
                      ( "USING" [insertopt]=<usingOption>
                                ( "AND" [insertopt]=<usingOption> )* )?
                    ;
<usingOption> ::= "TIMESTAMP" <wholenumber>
                | "TTL" <wholenumber>
                ;
'''


def regular_column_names(table_meta):
    """
    Return names of the columns that belong to neither the partition key nor the clustering key.

    An empty list is returned when *table_meta* is falsy or has no columns.
    The order of the returned names is unspecified.
    """
    if not table_meta or not table_meta.columns:
        return []
    key_names = {col.name for col in table_meta.partition_key}
    key_names |= {col.name for col in table_meta.clustering_key}
    return list(set(table_meta.columns) - key_names)


@completer_for('insertStatement', 'colname')
def insert_colname_completer(ctxt, cass):
    """
    Suggest columns for the column list of an INSERT.

    Primary-key columns not yet listed are suggested one at a time, in key
    order. Once all key columns are present, every regular column that has
    not been listed is offered.
    """
    layout = get_table_meta(ctxt, cass)
    already_listed = {dequote_name(name) for name in ctxt.get_binding('colname', ())}
    for key_col in layout.primary_key:
        if key_col.name not in already_listed:
            return [maybe_escape_name(key_col.name)]
    remaining = set(regular_column_names(layout)) - already_listed
    return [maybe_escape_name(name) for name in remaining]


@completer_for('insertStatement', 'newval')
def insert_newval_completer(ctxt, cass):
    """
    Suggest how to start the VALUES entry for the next listed column.

    Collections get an opening brace or bracket, booleans get true/false,
    and anything else gets a descriptive hint. Nothing is offered once
    every listed column has a value.
    """
    layout = get_table_meta(ctxt, cass)
    target_cols = [dequote_name(name) for name in ctxt.get_binding('colname')]
    value_index = len(ctxt.get_binding('newval', ()))
    if value_index >= len(target_cols):
        return []
    column = target_cols[value_index]
    cql_type = layout.columns[column].cql_type
    if cql_type.startswith(('map<', 'set<')):
        return ['{']
    if cql_type.startswith(('list<', 'vector<')):
        return ['[']
    if cql_type == 'boolean':
        return ['true', 'false']
    return [Hint('<value for %s (%s)>' % (maybe_escape_name(column), cql_type))]


@completer_for('insertStatement', 'valcomma')
def insert_valcomma_completer(ctxt, cass):
    """Offer ',' while fewer values than columns have been given, else ')'."""
    column_count = len(ctxt.get_binding('colname', ()))
    value_count = len(ctxt.get_binding('newval', ()))
    return [','] if column_count > value_count else [')']


@completer_for('insertStatement', 'insertopt')
def insert_option_completer(ctxt, cass):
    """
    Suggest the USING options (TIMESTAMP, TTL) not yet given in this INSERT.

    The keyword of each option already matched is upper-cased before it is
    removed, because CQL keywords are case-insensitive. Typing
    ``using ttl 5 and`` must therefore stop TTL from being suggested again.

    Returns:
        set of remaining option keywords.
    """
    opts = {'TIMESTAMP', 'TTL'}
    for opt in ctxt.get_binding('insertopt', ()):
        opts.discard(opt.split()[0].upper())
    return opts


# UPDATE grammar, plus the <assignment> and <condition> rules. <conditions>
# is also used by DELETE's IF clause. The bindings updatecol, update_rhs,
# counterop, inc, listadder, listcol, indexbracket, udt_field_dot, udt_field
# and conditioncol are consumed by the assignment/condition completers below.
syntax_rules += r'''
<updateStatement> ::= "UPDATE" cf=<columnFamilyName>
                        ( "USING" [updateopt]=<usingOption>
                                  ( "AND" [updateopt]=<usingOption> )* )?
                        "SET" <assignment> ( "," <assignment> )*
                        "WHERE" <whereClause>
                        ( "IF" ( "EXISTS" | <conditions> ))?
                    ;
<assignment> ::= updatecol=<cident>
                    (( "=" update_rhs=( <term> | <cident> )
                                ( counterop=( "+" | "-" ) inc=<wholenumber>
                                | listadder="+" listcol=<cident> )? )
                    | ( indexbracket="[" <term> "]" "=" <term> )
                    | ( udt_field_dot="." udt_field=<identifier> "=" <term> ))
               ;
<conditions> ::=  <condition> ( "AND" <condition> )*
               ;
<condition_op_and_rhs> ::= (("=" | "<" | ">" | "<=" | ">=" | "!=" | "CONTAINS" ( "KEY" )? ) <term>)
                           | ("IN" "(" <term> ( "," <term> )* ")" )
                         ;
<condition> ::= conditioncol=<cident>
                    ( (( indexbracket="[" <term> "]" )
                      |( udt_field_dot="." udt_field=<identifier> )) )?
                    <condition_op_and_rhs>
              ;
'''


@completer_for('updateStatement', 'updateopt')
def update_option_completer(ctxt, cass):
    """
    Suggest the USING options (TIMESTAMP, TTL) not yet given in this UPDATE.

    Each previously matched option keyword is upper-cased before removal,
    because CQL keywords are case-insensitive. A lower-case ``ttl`` must
    suppress TTL just as ``TTL`` does.

    Returns:
        set of remaining option keywords.
    """
    opts = {'TIMESTAMP', 'TTL'}
    for opt in ctxt.get_binding('updateopt', ()):
        opts.discard(opt.split()[0].upper())
    return opts


@completer_for('assignment', 'updatecol')
def update_col_completer(ctxt, cass):
    """Suggest non-key (regular) columns as SET targets, escaped for CQL."""
    table = get_table_meta(ctxt, cass)
    return [maybe_escape_name(name) for name in regular_column_names(table)]


@completer_for('assignment', 'update_rhs')
def update_countername_completer(ctxt, cass):
    """
    Suggest the start of the right-hand side of ``col = ...`` in SET.

    A counter column suggests itself (for ``c = c + n``). Collections get an
    opening brace or bracket, and other types get a typed hint.
    """
    layout = get_table_meta(ctxt, cass)
    colname = dequote_name(ctxt.get_binding('updatecol', ''))
    cql_type = layout.columns[colname].cql_type
    if cql_type == 'counter':
        return [maybe_escape_name(colname)]
    if cql_type.startswith(('map<', 'set<')):
        return ['{']
    if cql_type.startswith(('list<', 'vector<')):
        return ['[']
    return [Hint('<term (%s)>' % cql_type)]


@completer_for('assignment', 'counterop')
def update_counterop_completer(ctxt, cass):
    """Offer '+' and '-' only when the column being set is a counter."""
    layout = get_table_meta(ctxt, cass)
    colname = dequote_name(ctxt.get_binding('updatecol', ''))
    if layout.columns[colname].cql_type != 'counter':
        return []
    return ['+', '-']


@completer_for('assignment', 'inc')
def update_counter_inc_completer(ctxt, cass):
    """Hint at a whole-number increment, but only for counter columns."""
    layout = get_table_meta(ctxt, cass)
    colname = dequote_name(ctxt.get_binding('updatecol', ''))
    is_counter = layout.columns[colname].cql_type == 'counter'
    return [Hint('<wholenumber>')] if is_counter else []


@completer_for('assignment', 'listadder')
def update_listadder_completer(ctxt, cass):
    """Offer '+' after a list literal on the RHS (for ``col = [..] + col``)."""
    starts_with_list = ctxt.get_binding('update_rhs').startswith('[')
    return ['+'] if starts_with_list else []


@completer_for('assignment', 'listcol')
def update_listcol_completer(ctxt, cass):
    """After ``col = [..] +``, suggest the column being updated itself."""
    if not ctxt.get_binding('update_rhs').startswith('['):
        return []
    target = dequote_name(ctxt.get_binding('updatecol'))
    return [maybe_escape_name(target)]


@completer_for('assignment', 'indexbracket')
def update_indexbracket_completer(ctxt, cass):
    """
    Offer '[' for element assignment (``col[key] = value``) on map and list columns.

    A collection column's ``cql_type`` is parameterized (e.g. ``map<text, int>``),
    so the base type name before ``<`` is compared. Comparing the whole string
    against bare 'map'/'list' never matched, and '[' was never offered.
    """
    layout = get_table_meta(ctxt, cass)
    curcol = dequote_name(ctxt.get_binding('updatecol', ''))
    coltype = layout.columns[curcol].cql_type
    if coltype.split('<', 1)[0].strip() in ('map', 'list'):
        return ['[']
    return []


@completer_for('assignment', 'udt_field_dot')
def update_udt_field_dot_completer(ctxt, cass):
    """Offer '.' (field access) when the column being set is a user-defined type."""
    layout = get_table_meta(ctxt, cass)
    colname = dequote_name(ctxt.get_binding('updatecol', ''))
    if _is_usertype(layout, colname):
        return ["."]
    return []


@completer_for('assignment', 'udt_field')
def assignment_udt_field_completer(ctxt, cass):
    """Suggest field names of the UDT column being set (empty for non-UDTs)."""
    return _usertype_fields(ctxt, cass,
                            get_table_meta(ctxt, cass),
                            dequote_name(ctxt.get_binding('updatecol', '')))


def _is_usertype(layout, curcol):
    """
    Return whether column *curcol* of *layout* should be treated as a user-defined type.

    A column counts as a UDT when its ``cql_type`` is neither a simple CQL
    type nor a collection/tuple type. Collection ``cql_type`` strings are
    parameterized (``map<text, int>``, ``list<int>``, ...), so the base name
    before ``<`` is what gets compared. Comparing the full string against bare
    names never matched, which made collection columns look like UDTs: '.' was
    offered and a UDT layout lookup was attempted for them.

    NOTE(review): ``frozen<...>`` types are still reported as UDTs, as before.
    """
    coltype = layout.columns[curcol].cql_type
    basetype = coltype.split('<', 1)[0].strip()
    return coltype not in simple_cql_types and basetype not in ('map', 'set', 'list', 'vector', 'tuple')


def _usertype_fields(ctxt, cass, layout, curcol):
    """
    Return the field names of UDT column *curcol*, or [] if it is not a UDT.

    The type layout is fetched through ``cass.get_usertype_layout`` using the
    keyspace bound to 'ksname' (dequoted), or None when unbound.
    """
    if _is_usertype(layout, curcol):
        type_name = layout.columns[curcol].cql_type
        keyspace = ctxt.get_binding('ksname', None)
        if keyspace is not None:
            keyspace = dequote_name(keyspace)
        type_layout = cass.get_usertype_layout(keyspace, type_name)
        return [name for (name, _ftype) in type_layout]
    return []


@completer_for('condition', 'indexbracket')
def condition_indexbracket_completer(ctxt, cass):
    """
    Offer '[' for element conditions (``IF col[key] ...``) on map and list columns.

    A collection column's ``cql_type`` is parameterized (e.g. ``list<int>``),
    so the base type name before ``<`` is compared. Comparing the whole string
    against bare 'map'/'list' never matched.
    """
    layout = get_table_meta(ctxt, cass)
    curcol = dequote_name(ctxt.get_binding('conditioncol', ''))
    coltype = layout.columns[curcol].cql_type
    if coltype.split('<', 1)[0].strip() in ('map', 'list'):
        return ['[']
    return []


@completer_for('condition', 'udt_field_dot')
def condition_udt_field_dot_completer(ctxt, cass):
    """Offer '.' (field access) when the condition column is a user-defined type."""
    layout = get_table_meta(ctxt, cass)
    colname = dequote_name(ctxt.get_binding('conditioncol', ''))
    if _is_usertype(layout, colname):
        return ["."]
    return []


@completer_for('condition', 'udt_field')
def condition_udt_field_completer(ctxt, cass):
    """Suggest field names of the UDT column used in a condition."""
    return _usertype_fields(ctxt, cass,
                            get_table_meta(ctxt, cass),
                            dequote_name(ctxt.get_binding('conditioncol', '')))


# DELETE grammar. The optional selectors may address a whole column, a
# collection element ("[" term "]") or a UDT field ("." identifier). The
# delopt and delcol bindings are completed by delete_opt_completer and
# delete_delcol_completer.
syntax_rules += r'''
<deleteStatement> ::= "DELETE" ( <deleteSelector> ( "," <deleteSelector> )* )?
                        "FROM" cf=<columnFamilyName>
                        ( "USING" [delopt]=<deleteOption> )?
                        "WHERE" <whereClause>
                        ( "IF" ( "EXISTS" | <conditions> ) )?
                    ;
<deleteSelector> ::= delcol=<cident>
                     ( ( "[" <term> "]" )
                     | ( "." <identifier> ) )?
                   ;
<deleteOption> ::= "TIMESTAMP" <wholenumber>
                 ;
'''


@completer_for('deleteStatement', 'delopt')
def delete_opt_completer(ctxt, cass):
    """Suggest TIMESTAMP for DELETE ... USING unless it is already present."""
    remaining = {'TIMESTAMP'}
    remaining.difference_update(opt.split()[0] for opt in ctxt.get_binding('delopt', ()))
    return remaining


@completer_for('deleteSelector', 'delcol')
def delete_delcol_completer(ctxt, cass):
    """Suggest regular (non-key) columns as DELETE selectors, escaped for CQL."""
    table = get_table_meta(ctxt, cass)
    return [maybe_escape_name(name) for name in regular_column_names(table)]


# BATCH grammar. Each member is an INSERT, UPDATE or DELETE, optionally
# followed by ';'. batch_opt_completer handles the batchopt binding.
syntax_rules += r'''
<batchStatement> ::= "BEGIN" ( "UNLOGGED" | "COUNTER" )? "BATCH"
                        ( "USING" [batchopt]=<usingOption>
                                  ( "AND" [batchopt]=<usingOption> )* )?
                        [batchstmt]=<batchStatementMember> ";"?
                            ( [batchstmt]=<batchStatementMember> ";"? )*
                     "APPLY" "BATCH"
                   ;
<batchStatementMember> ::= <insertStatement>
                         | <updateStatement>
                         | <deleteStatement>
                         ;
'''


@completer_for('batchStatement', 'batchopt')
def batch_opt_completer(ctxt, cass):
    """
    Suggest TIMESTAMP for BEGIN BATCH ... USING unless it was already given.

    The grammar allows several ``AND``-joined options, so the keyword of each
    matched option is upper-cased before removal. Because CQL keywords are
    case-insensitive, ``using timestamp 1 and`` must not re-offer TIMESTAMP.

    Returns:
        set of remaining option keywords.
    """
    opts = {'TIMESTAMP'}
    for opt in ctxt.get_binding('batchopt', ()):
        opts.discard(opt.split()[0].upper())
    return opts


# TRUNCATE grammar; the COLUMNFAMILY/TABLE keyword is optional.
syntax_rules += r'''
<truncateStatement> ::= "TRUNCATE" ("COLUMNFAMILY" | "TABLE")? cf=<columnFamilyName>
                      ;
'''

# CREATE KEYSPACE grammar. 'wat' binds the KEYSPACE/SCHEMA keyword so that
# create_ks_wat_completer can decide which spelling(s) to offer.
syntax_rules += r'''
<createKeyspaceStatement> ::= "CREATE" wat=( "KEYSPACE" | "SCHEMA" ) ("IF" "NOT" "EXISTS")?  ksname=<cfOrKsName>
                                "WITH" <property> ( "AND" <property> )*
                            ;
'''


@completer_for('createKeyspaceStatement', 'wat')
def create_ks_wat_completer(ctxt, cass):
    """
    Complete the KEYSPACE/SCHEMA keyword of CREATE KEYSPACE.

    With nothing typed yet only KEYSPACE is offered; the legacy SCHEMA
    synonym is offered only once a partial word exists.
    """
    if ctxt.get_binding('partial', '') != '':
        return ['KEYSPACE', 'SCHEMA']
    return ['KEYSPACE']


# CREATE TABLE grammar. <singleKeyCfSpec> marks one column PRIMARY KEY
# inline. <compositeKeyCfSpec> instead uses a trailing PRIMARY KEY (...)
# clause, whose first element may be a parenthesized composite partition key
# (<pkDef>). The bindings newcolname, pkey, ptkey, k, p and c drive the
# completers that follow.
syntax_rules += r'''
<createColumnFamilyStatement> ::= "CREATE" wat=( "COLUMNFAMILY" | "TABLE" ) ("IF" "NOT" "EXISTS")?
                                    ( ks=<nonSystemKeyspaceName> dot="." )? cf=<cfOrKsName>
                                    "(" ( <singleKeyCfSpec> | <compositeKeyCfSpec> ) ")"
                                   ( "WITH" <cfamProperty> ( "AND" <cfamProperty> )* )?
                                ;

<cfamProperty> ::= <property>
                 | "COMPACT" "STORAGE" "CDC"
                 | "CLUSTERING" "ORDER" "BY" "(" <cfamOrdering>
                                                 ( "," <cfamOrdering> )* ")"
                 ;

<cfamOrdering> ::= [ordercol]=<cident> ( "ASC" | "DESC" )
                 ;

<singleKeyCfSpec> ::= [newcolname]=<cident> <storageType> "PRIMARY" "KEY"
                      ( "," [newcolname]=<cident> <storageType> )*
                    ;

<compositeKeyCfSpec> ::= [newcolname]=<cident> <storageType>
                         "," [newcolname]=<cident> <storageType> ( "static" )?
                         ( "," [newcolname]=<cident> <storageType> ( "static" )? )*
                         "," "PRIMARY" k="KEY" p="(" ( partkey=<pkDef> | [pkey]=<cident> )
                                                     ( c="," [pkey]=<cident> )* ")"
                       ;

<pkDef> ::= "(" [ptkey]=<cident> "," [ptkey]=<cident>
                               ( "," [ptkey]=<cident> )* ")"
          ;
'''


@completer_for('cfamOrdering', 'ordercol')
def create_cf_clustering_order_colname_completer(ctxt, cass):
    """
    Suggest columns for CLUSTERING ORDER BY from the columns declared so far.

    Declared names are dequoted and then re-escaped with maybe_escape_name,
    as the other column completers do. A quoted, case-sensitive column is then
    offered in valid CQL form rather than as a bare identifier.
    """
    colnames = [dequote_name(name) for name in ctxt.get_binding('newcolname', ())]
    # Definitely some of these aren't valid for ordering, but I'm not sure
    # precisely which are. This is good enough for now
    return [maybe_escape_name(name) for name in colnames]


@completer_for('createColumnFamilyStatement', 'wat')
def create_cf_wat_completer(ctxt, cass):
    """
    Complete the TABLE/COLUMNFAMILY keyword of CREATE TABLE.

    Only TABLE is offered until a partial word is typed; the legacy
    COLUMNFAMILY spelling is then offered too.
    """
    if ctxt.get_binding('partial', '') != '':
        return ['TABLE', 'COLUMNFAMILY']
    return ['TABLE']


# Names the user must invent cannot be completed; show placeholder text instead.
explain_completion('createColumnFamilyStatement', 'cf', '<new_table_name>')
explain_completion('compositeKeyCfSpec', 'newcolname', '<new_column_name>')


@completer_for('createColumnFamilyStatement', 'dot')
def create_cf_ks_dot_completer(ctxt, cass):
    """Offer '.' after the keyspace prefix only if that keyspace exists."""
    keyspace = dequote_name(ctxt.get_binding('ks'))
    return ['.'] if keyspace in cass.get_keyspace_names() else []


@completer_for('pkDef', 'ptkey')
def create_cf_pkdef_declaration_completer(ctxt, cass):
    """
    Suggest the next declared column for a composite partition key ``((a, b, ...))``.

    Declared columns are walked in order, skipping those already used in the
    partition key. Nothing is suggested once fewer than two declared columns
    remain.

    Both the declared names and the used pieces are dequoted before they are
    compared. Previously only the pieces were dequoted, so a quoted column
    never matched, and re-escaping the still-quoted name produced an invalid
    suggestion.
    """
    cols_declared = [dequote_name(name) for name in ctxt.get_binding('newcolname')]
    pieces_already = [dequote_name(name) for name in ctxt.get_binding('ptkey', ())]
    while cols_declared[0] in pieces_already:
        cols_declared = cols_declared[1:]
        if len(cols_declared) < 2:
            return ()
    return [maybe_escape_name(cols_declared[0])]


@completer_for('compositeKeyCfSpec', 'pkey')
def create_cf_composite_key_declaration_completer(ctxt, cass):
    """
    Suggest the next declared column for the PRIMARY KEY (...) clause.

    Columns already used as partition-key pieces (ptkey) or key pieces (pkey)
    are skipped in declaration order. Nothing is suggested once fewer than two
    declared columns remain.

    Declared names are dequoted just like the used pieces, so quoted column
    names compare correctly and are escaped exactly once.
    """
    cols_declared = [dequote_name(name) for name in ctxt.get_binding('newcolname')]
    pieces_already = ctxt.get_binding('ptkey', ()) + ctxt.get_binding('pkey', ())
    pieces_already = [dequote_name(name) for name in pieces_already]
    while cols_declared[0] in pieces_already:
        cols_declared = cols_declared[1:]
        if len(cols_declared) < 2:
            return ()
    return [maybe_escape_name(cols_declared[0])]


@completer_for('compositeKeyCfSpec', 'k')
def create_cf_composite_primary_key_keyword_completer(ctxt, cass):
    """After PRIMARY, complete 'KEY (' in one step."""
    suggestion = 'KEY ('
    return [suggestion]


@completer_for('compositeKeyCfSpec', 'p')
def create_cf_composite_primary_key_paren_completer(ctxt, cass):
    """Offer the opening parenthesis of the PRIMARY KEY column list."""
    suggestion = '('
    return [suggestion]


@completer_for('compositeKeyCfSpec', 'c')
def create_cf_composite_primary_key_comma_completer(ctxt, cass):
    """Offer ',' in PRIMARY KEY (...) while fewer than (declared - 1) key columns are listed."""
    declared_count = len(ctxt.get_binding('newcolname'))
    used_count = len(ctxt.get_binding('pkey', ()))
    if used_count < declared_count - 1:
        return [',']
    return ()


# CREATE TABLE ... LIKE grammar. tks/tcf name the table being created and
# sks/scf name the existing table it is modeled on.
syntax_rules += r'''
<copyTableStatement> ::= "CREATE" wat=("COLUMNFAMILY" | "TABLE" ) ("IF" "NOT" "EXISTS")?
                                ( tks=<nonSystemKeyspaceName> dot="." )? tcf=<cfOrKsName>
                                "LIKE" ( sks=<nonSystemKeyspaceName> dot="." )? scf=<cfOrKsName>
                                ( "WITH" <propertyOrOption> ( "AND" <propertyOrOption> )* )?
                            ;
'''


@completer_for('copyTableStatement', 'wat')
def create_tb_wat_completer(ctxt, cass):
    """
    Complete the TABLE/COLUMNFAMILY keyword of CREATE TABLE ... LIKE.

    We would prefer to retire the "columnfamily" nomenclature, so only TABLE
    is offered until a partial word has been typed.
    """
    if ctxt.get_binding('partial', '') != '':
        return ['COLUMNFAMILY', 'TABLE']
    return ['TABLE']


# Placeholders for the target (new) and source (existing) tables of CREATE TABLE ... LIKE.
explain_completion('copyTableStatement', 'tcf', '<new_table_name>')
explain_completion('copyTableStatement', 'scf', '<old_table_name>')


# Grammar for CREATE INDEX, CREATE MATERIALIZED VIEW, CREATE TYPE,
# CREATE FUNCTION and CREATE AGGREGATE.
# NOTE(review): "keys(" and "full(" are single literal terminals; presumably
# they exist so completion inserts the function name and paren together.
# Confirm they can also match input the lexer splits into separate tokens.
syntax_rules += r'''

<idxName> ::= <identifier>
            | <quotedName>
            | <unreservedKeyword>;

<createIndexStatement> ::= "CREATE" "CUSTOM"? "INDEX" ("IF" "NOT" "EXISTS")? indexname=<idxName>? "ON"
                               cf=<columnFamilyName> "(" (
                                   col=<cident> |
                                   "keys(" col=<cident> ")" |
                                   "full(" col=<cident> ")"
                               ) ")"
                               ( "USING" <stringLiteral> ( "WITH" "OPTIONS" "=" <mapLiteral> )? )?
                         ;


<colList> ::= "(" <cident> ( "," <cident> )* ")"
          ;

<createMaterializedViewStatement> ::= "CREATE" wat="MATERIALIZED" "VIEW" ("IF" "NOT" "EXISTS")? viewname=<materializedViewName>?
                                      "AS" "SELECT" <selectClause>
                                      "FROM" cf=<columnFamilyName>
                                      "WHERE" <cident> "IS" "NOT" "NULL" ( "AND" <cident> "IS" "NOT" "NULL")*
                                      "PRIMARY" "KEY" (<colList> | ( "(" <colList> ( "," <cident> )* ")" ))
                                      ( "WITH" <cfamProperty> ( "AND" <cfamProperty> )* )?
                                    ;

<createUserTypeStatement> ::= "CREATE" "TYPE" ("IF" "NOT" "EXISTS")? ( ks=<nonSystemKeyspaceName> dot="." )? typename=<cfOrKsName> "(" newcol=<cident> <storageType>
                                ( "," [newcolname]=<cident> <storageType> )*
                            ")"
                         ;

<createFunctionStatement> ::= "CREATE" ("OR" "REPLACE")? "FUNCTION"
                            ("IF" "NOT" "EXISTS")?
                            <userFunctionName>
                            ( "(" ( newcol=<cident> <storageType>
                              ( "," [newcolname]=<cident> <storageType> )* )?
                            ")" )?
                            ("RETURNS" "NULL" | "CALLED") "ON" "NULL" "INPUT"
                            "RETURNS" <storageType>
                            "LANGUAGE" <cident> "AS" <stringLiteral>
                         ;

<createAggregateStatement> ::= "CREATE" ("OR" "REPLACE")? "AGGREGATE"
                            ("IF" "NOT" "EXISTS")?
                            <userAggregateName>
                            ( "("
                                 ( <storageType> ( "," <storageType> )* )?
                              ")" )?
                            "SFUNC" <refUserFunctionName>
                            "STYPE" <storageType>
                            ( "FINALFUNC" <refUserFunctionName> )?
                            ( "INITCOND" <term> )?
                         ;

'''

# Placeholder text for new-object names that cannot be completed from schema.
explain_completion('createIndexStatement', 'indexname', '<new_index_name>')
explain_completion('createMaterializedViewStatement', 'viewname', '<new_view_name>')
explain_completion('createUserTypeStatement', 'typename', '<new_type_name>')
explain_completion('createUserTypeStatement', 'newcol', '<new_field_name>')


@completer_for('createIndexStatement', 'col')
def create_index_col_completer(ctxt, cass):
    """Suggest the table's columns that are not already an index target, in column order."""
    layout = get_table_meta(ctxt, cass)
    indexed = {idx.index_options["target"] for idx in layout.indexes.values()}
    return [maybe_escape_name(col.name)
            for col in layout.columns.values()
            if col.name not in indexed]


# DROP grammar for keyspaces, tables, indexes, materialized views, types,
# functions and aggregates. <indexName> supports an optional keyspace prefix
# completed by the indexName completers below.
syntax_rules += r'''
<dropKeyspaceStatement> ::= "DROP" "KEYSPACE" ("IF" "EXISTS")? ksname=<nonSystemKeyspaceName>
                          ;

<dropColumnFamilyStatement> ::= "DROP" ( "COLUMNFAMILY" | "TABLE" ) ("IF" "EXISTS")? cf=<columnFamilyName>
                              ;

<indexName> ::= ( ksname=<idxOrKsName> dot="." )? idxname=<idxOrKsName> ;

<idxOrKsName> ::= <identifier>
               | <quotedName>
               | <unreservedKeyword>;

<dropIndexStatement> ::= "DROP" "INDEX" ("IF" "EXISTS")? idx=<indexName>
                       ;

<dropMaterializedViewStatement> ::= "DROP" "MATERIALIZED" "VIEW" ("IF" "EXISTS")? mv=<materializedViewName>
                                  ;

<dropUserTypeStatement> ::= "DROP" "TYPE" ( "IF" "EXISTS" )? ut=<userTypeName>
                          ;

<dropFunctionStatement> ::= "DROP" "FUNCTION" ( "IF" "EXISTS" )? <userFunctionName>
                          ;

<dropAggregateStatement> ::= "DROP" "AGGREGATE" ( "IF" "EXISTS" )? <userAggregateName>
                          ;

'''


@completer_for('indexName', 'ksname')
def idx_ks_name_completer(ctxt, cass):
    """Suggest every keyspace name, escaped and followed by '.', as an index-name prefix."""
    return ['%s.' % maybe_escape_name(name) for name in cass.get_keyspace_names()]


@completer_for('indexName', 'dot')
def idx_ks_dot_completer(ctxt, cass):
    """Offer '.' after an index-name prefix only when it names an existing keyspace."""
    keyspace = dequote_name(ctxt.get_binding('ksname'))
    return ['.'] if keyspace in cass.get_keyspace_names() else []


@completer_for('indexName', 'idxname')
def idx_ks_idx_name_completer(ctxt, cass):
    """
    Complete index names, optionally within the keyspace bound to 'ksname'.

    A lookup error yields no candidates only when no keyspace was given;
    otherwise it propagates.
    """
    keyspace = ctxt.get_binding('ksname', None)
    keyspace = None if keyspace is None else dequote_name(keyspace)
    try:
        names = cass.get_index_names(keyspace)
    except Exception:
        if keyspace is not None:
            raise
        return ()
    return [maybe_escape_name(name) for name in names]


# ALTER TABLE and ALTER TYPE grammar. The 'existcol' bindings are completed
# from the table or type layout; 'newcol' only gets a placeholder explanation.
syntax_rules += r'''
<alterTableStatement> ::= "ALTER" wat=( "COLUMNFAMILY" | "TABLE" ) ("IF" "EXISTS")? cf=<columnFamilyName>
                               <alterInstructions>
                        ;
<alterInstructions> ::= "ADD" ("IF" "NOT" "EXISTS")? newcol=<cident> <storageType> ("static")?
                      | "DROP" ("IF" "EXISTS")? existcol=<cident>
                      | "WITH" <cfamProperty> ( "AND" <cfamProperty> )*
                      | "RENAME" ("IF" "EXISTS")? existcol=<cident> "TO" newcol=<cident>
                         ( "AND" existcol=<cident> "TO" newcol=<cident> )*
                      | "ALTER" ("IF" "EXISTS")? existcol=<cident> ( <constraintsExpr> | <column_mask> | "DROP" ( "CHECK" | "MASKED" ) )
                      ;

<alterUserTypeStatement> ::= "ALTER" "TYPE" ("IF" "EXISTS")? ut=<userTypeName>
                               <alterTypeInstructions>
                             ;
<alterTypeInstructions> ::= "ADD" ("IF" "NOT" "EXISTS")? newcol=<cident> <storageType>
                           | "RENAME" ("IF" "EXISTS")? existcol=<cident> "TO" newcol=<cident>
                              ( "AND" existcol=<cident> "TO" newcol=<cident> )*
                           ;
'''


@completer_for('alterInstructions', 'existcol')
def alter_table_col_completer(ctxt, cass):
    """Suggest every existing column of the table being altered, escaped for CQL."""
    layout = get_table_meta(ctxt, cass)
    return [maybe_escape_name(str(column)) for column in layout.columns]


@completer_for('alterTypeInstructions', 'existcol')
def alter_type_field_completer(ctxt, cass):
    """Suggest existing field names (first element of each layout entry) of the type being altered."""
    type_layout = get_ut_layout(ctxt, cass)
    return [maybe_escape_name(entry[0]) for entry in type_layout]


# New column/field names in ALTER statements get placeholder text only.
explain_completion('alterInstructions', 'newcol', '<new_column_name>')
explain_completion('alterTypeInstructions', 'newcol', '<new_field_name>')


# ALTER KEYSPACE grammar; restricted to keyspaces matched by <alterableKeyspaceName>.
syntax_rules += r'''
<alterKeyspaceStatement> ::= "ALTER" wat=( "KEYSPACE" | "SCHEMA" ) ("IF" "EXISTS")? ks=<alterableKeyspaceName>
                                 "WITH" <property> ( "AND" <property> )*
                           ;
'''

# USER management grammar (CREATE/ALTER/DROP/LIST USERS). The 'name' binding
# of <username> is completed by username_name_completer.
syntax_rules += r'''
<username> ::= name=( <identifier> | <stringLiteral> )
             ;

<createUserStatement> ::= "CREATE" "USER" ( "IF" "NOT" "EXISTS" )? <username>
                              ( ("WITH" ("HASHED")? "PASSWORD" <stringLiteral>) | ("WITH" "GENERATED" "PASSWORD") )?
                              ( "SUPERUSER" | "NOSUPERUSER" )?
                        ;

<alterUserStatement> ::= "ALTER" "USER" ("IF" "EXISTS")? <username>
                              ( ("WITH" "PASSWORD" <stringLiteral>) | ("WITH" "GENERATED" "PASSWORD") )?
                              ( "SUPERUSER" | "NOSUPERUSER" )?
                       ;

<dropUserStatement> ::= "DROP" "USER" ( "IF" "EXISTS" )? <username>
                      ;

<listUsersStatement> ::= "LIST" "USERS"
                       ;
'''

# ROLE management grammar (CREATE/ALTER/DROP/GRANT/REVOKE/LIST ROLES). The
# 'role' binding of <rolename> is completed by rolename_completer.
syntax_rules += r'''
<rolename> ::= role=( <identifier>
             | <quotedName>
             | <unreservedKeyword> )
             ;

<createRoleStatement> ::= "CREATE" "ROLE" ("IF" "NOT" "EXISTS")? <rolename>
                              ( "WITH" <roleProperty> ("AND" <roleProperty>)*)?
                        ;

<alterRoleStatement> ::= "ALTER" "ROLE" ("IF" "EXISTS")? <rolename>
                              ( "WITH" <roleProperty> ("AND" <roleProperty>)*)
                       ;

<roleProperty> ::= (("HASHED")? "PASSWORD") "=" <stringLiteral>
                 | "GENERATED" "PASSWORD"
                 | "OPTIONS" "=" <mapLiteral>
                 | "SUPERUSER" "=" <boolean>
                 | "LOGIN" "=" <boolean>
                 | "ACCESS" "TO" "DATACENTERS" <setLiteral>
                 | "ACCESS" "TO" "ALL" "DATACENTERS"
                 | "ACCESS" "FROM" "CIDRS" <setLiteral>
                 | "ACCESS" "FROM" "ALL" "CIDRS"
                 ;

<dropRoleStatement> ::= "DROP" "ROLE" ("IF" "EXISTS")? <rolename>
                      ;

<grantRoleStatement> ::= "GRANT" <rolename> "TO" <rolename>
                       ;

<revokeRoleStatement> ::= "REVOKE" <rolename> "FROM" <rolename>
                        ;

<listRolesStatement> ::= "LIST" "ROLES"
                              ( "OF" <rolename> )? "NORECURSIVE"?
                       ;

<listSuperUsersStatement> ::= "LIST" "SUPERUSERS"
                       ;
'''

# Permission grammar: GRANT/REVOKE/LIST of permissions on data, role,
# function and JMX resources. permission_completer reads the literal
# alternatives of <permission> from the rule set, so adding a permission here
# also makes it completable.
# NOTE(review): "IN KEYSPACE" is a single literal terminal; confirm it can
# match input that is tokenized as separate IN and KEYSPACE tokens.
syntax_rules += r'''
<grantStatement> ::= "GRANT" <permissionExpr> "ON" <resource> "TO" <rolename>
                   ;

<revokeStatement> ::= "REVOKE" <permissionExpr> "ON" <resource> "FROM" <rolename>
                    ;

<listPermissionsStatement> ::= "LIST" <permissionExpr>
                                    ( "ON" <resource> )? ( "OF" <rolename> )? "NORECURSIVE"?
                             ;

<permission> ::= "AUTHORIZE"
               | "CREATE"
               | "ALTER"
               | "DROP"
               | "SELECT"
               | "MODIFY"
               | "DESCRIBE"
               | "EXECUTE"
               | "UNMASK"
               | "SELECT_MASKED"
               ;

<permissionExpr> ::= ( [newpermission]=<permission> "PERMISSION"? ( "," [newpermission]=<permission> "PERMISSION"? )* )
                   | ( "ALL" "PERMISSIONS"? )
                   ;

<resource> ::= <dataResource>
             | <roleResource>
             | <functionResource>
             | <jmxResource>
             ;

<dataResource> ::= ( "ALL" "KEYSPACES" )
                 | ( "KEYSPACE" <keyspaceName> )
                 | ( "ALL" "TABLES" "IN" "KEYSPACE" <keyspaceName> )
                 | ( "TABLE"? <columnFamilyName> )
                 ;

<roleResource> ::= ("ALL" "ROLES")
                 | ("ROLE" <rolename>)
                 ;

<functionResource> ::= ( "ALL" "FUNCTIONS" ("IN KEYSPACE" <keyspaceName>)? )
                     | ( "FUNCTION" <functionAggregateName>
                           ( "(" ( newcol=<cident> <storageType>
                             ( "," [newcolname]=<cident> <storageType> )* )?
                           ")" )
                       )
                     ;

<jmxResource> ::= ( "ALL" "MBEANS")
                | ( ( "MBEAN" | "MBEANS" ) <stringLiteral> )
                ;

'''


@completer_for('permissionExpr', 'newpermission')
def permission_completer(ctxt, _):
    """
    Suggest permissions not yet listed in this permission expression.

    The candidate set comes from the literal alternatives of the <permission>
    rule. Already-listed permissions are upper-cased before comparison
    because they may be typed in any case. The binding defaults to an empty
    tuple so that an absent binding (nothing collected yet) yields the full
    candidate set instead of failing on iteration over None.

    Returns:
        set of permission keywords, or a single Hint when none remain.
    """
    already_listed = {permission.upper() for permission in ctxt.get_binding('newpermission', ())}
    all_permissions = {permission.arg for permission in ctxt.ruleset['permission'].arg}
    suggestions = all_permissions - already_listed
    if not suggestions:
        return [Hint('No more permissions here.')]
    return suggestions


@completer_for('username', 'name')
def username_name_completer(ctxt, cass):
    """
    Complete existing user names from the output of ``LIST USERS``.

    Completion is disabled (a hint is shown) for CREATE USER, where the name
    is new. Returns a list, like the other completers in this module, rather
    than a lazy, single-use ``map`` iterator.
    """
    # disable completion for CREATE USER.
    if ctxt.matched[0][1].upper() == 'CREATE':
        return [Hint('<username>')]

    session = cass.session
    return [maybe_escape_name(row['name']) for row in session.execute("LIST USERS")]


@completer_for('rolename', 'role')
def rolename_completer(ctxt, cass):
    """
    Complete existing role names from the output of ``LIST ROLES``.

    Completion is disabled (a hint is shown) for CREATE ROLE, where the name
    is new. Returns a list, like the other completers in this module, rather
    than a lazy, single-use ``map`` iterator.
    """
    # disable completion for CREATE ROLE.
    if ctxt.matched[0][1].upper() == 'CREATE':
        return [Hint('<rolename>')]

    session = cass.session
    return [maybe_escape_name(row['role']) for row in session.execute("LIST ROLES")]


# CREATE/DROP TRIGGER grammar. The trigger class is given as a string literal.
syntax_rules += r'''
<createTriggerStatement> ::= "CREATE" "TRIGGER" ( "IF" "NOT" "EXISTS" )? <cident>
                               "ON" cf=<columnFamilyName> "USING" class=<stringLiteral>
                           ;
<dropTriggerStatement> ::= "DROP" "TRIGGER" ( "IF" "EXISTS" )? triggername=<cident>
                             "ON" cf=<columnFamilyName>
                         ;
'''
# The trigger class cannot be completed; show a descriptive placeholder instead.
explain_completion('createTriggerStatement', 'class', '\'fully qualified class name\'')


def get_trigger_names(ctxt, cass):
    """
    Return trigger names from *cass* for the keyspace bound to 'ksname'.

    The keyspace is dequoted when bound; None is passed when it is absent.
    """
    keyspace = ctxt.get_binding('ksname', None)
    return cass.get_trigger_names(None if keyspace is None else dequote_name(keyspace))


@completer_for('dropTriggerStatement', 'triggername')
def drop_trigger_completer(ctxt, cass):
    """Suggest existing trigger names for DROP TRIGGER, escaped for CQL."""
    return [maybe_escape_name(name) for name in get_trigger_names(ctxt, cass)]


# END SYNTAX/COMPLETION RULE DEFINITIONS

# Register every grammar fragment accumulated in syntax_rules with the CQL
# rule set.
CqlRuleSet.append_rules(syntax_rules)
