import psycopg2
import psycopg2.extras
from psycopg2.extensions import QuotedString
import os
import aiopg
import json
import math
import re
import time
from services.utils import logging, DBType
from typing import List, Tuple

from .db_utils import DBResponse, DBPagination, aiopg_exception_handling, \
    get_db_ts_epoch_str, translate_run_key, translate_task_key, new_heartbeat_ts
from .models import FlowRow, RunRow, StepRow, TaskRow, MetadataRow, ArtifactRow
from services.utils import DBConfiguration, USE_SEPARATE_READER_POOL

from services.data.service_configs import max_connection_retires, \
    connection_retry_wait_time_seconds

AIOPG_ECHO = os.environ.get("AIOPG_ECHO", "0") == "1"

WAIT_TIME = 10
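# WAIT_TIME is the minimum age (in seconds) an existing heartbeat must reach
# before it is refreshed again; it is also returned to clients as
# "wait_time_in_seconds".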

# Create database triggers automatically, disabled by default
# Enable with env variable `DB_TRIGGER_CREATE=1`
DB_TRIGGER_CREATE = os.environ.get("DB_TRIGGER_CREATE", "0") == "1"

# Configure DB Table names. Custom names can be supplied through environment variables,
# in case the deployment differs from the default naming scheme from the supplied migrations.
FLOW_TABLE_NAME = os.environ.get("DB_TABLE_NAME_FLOWS", "flows_v3")
RUN_TABLE_NAME = os.environ.get("DB_TABLE_NAME_RUNS", "runs_v3")
STEP_TABLE_NAME = os.environ.get("DB_TABLE_NAME_STEPS", "steps_v3")
TASK_TABLE_NAME = os.environ.get("DB_TABLE_NAME_TASKS", "tasks_v3")
METADATA_TABLE_NAME = os.environ.get("DB_TABLE_NAME_METADATA", "metadata_v3")
ARTIFACT_TABLE_NAME = os.environ.get("DB_TABLE_NAME_ARTIFACT", "artifact_v3")
DB_SCHEMA_NAME = os.environ.get("DB_SCHEMA_NAME", "public")

operator_match = re.compile('([^:]*):([=><]+)$')
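# operator_match splits filter keys of the form "<column>:<operator>"
# (e.g. "last_heartbeat_ts:<=") into the column name and the comparison operator.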

# use a ddmmyyyy timestamp as the version for triggers
TRIGGER_VERSION = "05092024"
TRIGGER_NAME_PREFIX = "notify_ui"


class _AsyncPostgresDB(object):
    connection = None
    flow_table_postgres = None
    run_table_postgres = None
    step_table_postgres = None
    task_table_postgres = None
    artifact_table_postgres = None
    metadata_table_postgres = None

    pool = None
    reader_pool = None
    db_conf: DBConfiguration = None

    def __init__(self, name='global'):
        self.name = name
        self.logger = logging.getLogger("AsyncPostgresDB:{name}".format(name=self.name))

        tables = []
        self.flow_table_postgres = AsyncFlowTablePostgres(self)
        self.run_table_postgres = AsyncRunTablePostgres(self)
        self.step_table_postgres = AsyncStepTablePostgres(self)
        self.task_table_postgres = AsyncTaskTablePostgres(self)
        self.artifact_table_postgres = AsyncArtifactTablePostgres(self)
        self.metadata_table_postgres = AsyncMetadataTablePostgres(self)
        tables.append(self.flow_table_postgres)
        tables.append(self.run_table_postgres)
        tables.append(self.step_table_postgres)
        tables.append(self.task_table_postgres)
        tables.append(self.artifact_table_postgres)
        tables.append(self.metadata_table_postgres)
        self.tables = tables

    async def _init(self, db_conf: DBConfiguration, create_triggers=DB_TRIGGER_CREATE):
        # Pool min/max size and timeout come from the DBConfiguration;
        # connection attempts are retried below with a fixed wait in between.
        retries = max_connection_retires
        for i in range(retries):
            try:
                self.pool = await aiopg.create_pool(
                    db_conf.get_dsn(),
                    minsize=db_conf.pool_min,
                    maxsize=db_conf.pool_max,
                    timeout=db_conf.timeout,
                    pool_recycle=10 * db_conf.timeout,
                    echo=AIOPG_ECHO)

                self.reader_pool = await aiopg.create_pool(
                    db_conf.get_dsn(type=DBType.READER),
                    minsize=db_conf.pool_min,
                    maxsize=db_conf.pool_max,
                    timeout=db_conf.timeout,
                    pool_recycle=10 * db_conf.timeout,
                    echo=AIOPG_ECHO) if USE_SEPARATE_READER_POOL else self.pool

                for table in self.tables:
                    await table._init(create_triggers=create_triggers)

                if USE_SEPARATE_READER_POOL:
                    self.logger.info(
                        "Writer Connection established.\n"
                        "   Pool min: {pool_min} max: {pool_max}\n".format(
                            pool_min=self.pool.minsize,
                            pool_max=self.pool.maxsize))

                    self.logger.info(
                        "Reader Connection established.\n"
                        "   Pool min: {pool_min} max: {pool_max}\n".format(
                            pool_min=self.reader_pool.minsize,
                            pool_max=self.reader_pool.maxsize))
                else:
                    self.logger.info(
                        "Connection established.\n"
                        "   Pool min: {pool_min} max: {pool_max}\n".format(
                            pool_min=self.pool.minsize,
                            pool_max=self.pool.maxsize))

                break  # Break the retry loop
            except Exception as e:
                self.logger.exception("Exception occurred")
                if retries - i <= 1:
                    raise e
                time.sleep(connection_retry_wait_time_seconds)

    def get_table_by_name(self, table_name: str):
        for table in self.tables:
            if table.table_name == table_name:
                return table
        return None

    async def get_run_ids(self, flow_id: str, run_id: str):
        run = await self.run_table_postgres.get_run(flow_id, run_id,
                                                    expanded=True)
        return run.body['run_number'], run.body['run_id']

    async def get_task_ids(self, flow_id: str, run_id: str,
                           step_name: str, task_name: str):

        task = await self.task_table_postgres.get_task(flow_id, run_id,
                                                       step_name, task_name,
                                                       expanded=True)
        return task.body['task_id'], task.body['task_name']


class AsyncPostgresDB(object):
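    """
    Process-wide singleton proxy: every attribute access is delegated to the
    single shared _AsyncPostgresDB instance created on first construction.

    Illustrative usage (configuration values are deployment-specific):
        db = AsyncPostgresDB.get_instance()
        await db._init(DBConfiguration())
        flows = await db.flow_table_postgres.get_all_flows()
    """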
    __instance = None

    @staticmethod
    def get_instance():
        return AsyncPostgresDB()

    def __init__(self):
        if not AsyncPostgresDB.__instance:
            AsyncPostgresDB.__instance = _AsyncPostgresDB()

    def __getattribute__(self, name):
        return getattr(AsyncPostgresDB.__instance, name)


class AsyncPostgresTable(object):
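    """
    Base class for the per-table accessors below. Subclasses provide the table
    name, column keys, row type and optional trigger configuration.
    """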
    db = None
    table_name = None
    schema_version = 1
    keys: List[str] = []
    primary_keys: List[str] = None
    trigger_keys: List[str] = None
    trigger_operations: List[str] = ["INSERT", "UPDATE", "DELETE"]
    trigger_conditions: List[str] = None
    ordering: List[str] = None
    joins: List[str] = None
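    # Note: this binds to the (empty) base-class `keys` list at class-definition
    # time; subclasses override select_columns with their own key list.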
    select_columns: List[str] = keys
    join_columns: List[str] = None
    _insert_command = None
    _filters = None
    _base_query = "SELECT {0} from"
    _row_type = None

    def __init__(self, db: _AsyncPostgresDB = None):
        self.db = db
        if self.table_name is None:
            raise NotImplementedError(
                "need to specify table name")

    async def _init(self, create_triggers: bool):
        if create_triggers:
            self.db.logger.info(
                "Setting up notify trigger for {table_name}\n   Keys: {keys}".format(
                    table_name=self.table_name, keys=self.trigger_keys))
            await PostgresUtils.cleanup_triggers(db=self.db, table_name=self.table_name)
            if self.trigger_keys and self.trigger_operations:
                await PostgresUtils.setup_trigger_notify(
                    db=self.db,
                    table_name=self.table_name,
                    keys=self.trigger_keys,
                    operations=self.trigger_operations,
                    conditions=self.trigger_conditions
                )

    async def get_records(self, filter_dict={}, fetch_single=False,
                          ordering: List[str] = None, limit: int = 0, expanded=False,
                          cur: aiopg.Cursor = None) -> DBResponse:
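        """
        Fetch rows matching the equality filters in filter_dict.

        Illustrative usage (the flow id here is hypothetical):
            resp = await table.get_records({"flow_id": "ExampleFlow"}, fetch_single=True)
        """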
        conditions = []
        values = []
        for col_name, col_val in filter_dict.items():
            conditions.append("{} = %s".format(col_name))
            values.append(col_val)

        response, _ = await self.find_records(
            conditions=conditions, values=values, fetch_single=fetch_single,
            order=ordering, limit=limit, expanded=expanded, cur=cur
        )
        return response

    async def find_records(self, conditions: List[str] = None, values=[], fetch_single=False,
                           limit: int = 0, offset: int = 0, order: List[str] = None, expanded=False,
                           enable_joins=False, cur: aiopg.Cursor = None) -> Tuple[DBResponse, DBPagination]:
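        # The base select (including any joins) is wrapped in a subquery aliased
        # T so that the outer WHERE, ORDER BY, LIMIT and OFFSET clauses operate
        # on the selected column aliases.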
        sql_template = """
        SELECT * FROM (
            SELECT
                {keys}
            FROM {table_name}
            {joins}
        ) T
        {where}
        {order_by}
        {limit}
        {offset}
        """

        select_sql = sql_template.format(
            keys=",".join(
                self.select_columns + (self.join_columns if enable_joins and self.join_columns else [])),
            table_name=self.table_name,
            joins=" ".join(self.joins) if enable_joins and self.joins is not None else "",
            where="WHERE {}".format(" AND ".join(conditions)) if conditions else "",
            order_by="ORDER BY {}".format(", ".join(order)) if order else "",
            limit="LIMIT {}".format(limit) if limit else "",
            offset="OFFSET {}".format(offset) if offset else ""
        ).strip()

        return await self.execute_sql(select_sql=select_sql, values=values, fetch_single=fetch_single,
                                      expanded=expanded, limit=limit, offset=offset, cur=cur)

    async def execute_sql(self, select_sql: str, values=[], fetch_single=False,
                          expanded=False, limit: int = 0, offset: int = 0,
                          cur: aiopg.Cursor = None, serialize: bool = True) -> Tuple[DBResponse, DBPagination]:
        async def _execute_on_cursor(_cur):
            await _cur.execute(select_sql, values)

            rows = []
            records = await _cur.fetchall()
            if serialize:
                for record in records:
                    # pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
                    # pylint: disable=not-callable
                    row = self._row_type(**record)
                    rows.append(row.serialize(expanded))
            else:
                rows = records

            count = len(rows)

            # Will raise IndexError if fetch_single=True and there are no results
            body = rows[0] if fetch_single else rows
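            # Page number is derived from offset/limit; max(limit, 1) guards
            # against division by zero when no explicit limit is set.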
            pagination = DBPagination(
                limit=limit,
                offset=offset,
                count=count,
                page=math.floor(int(offset) / max(int(limit), 1)) + 1,
            )
            return body, pagination

        try:
            if cur:
                # if we are using the passed in cursor, we allow any errors to be managed by cursor owner
                body, pagination = await _execute_on_cursor(cur)
                return DBResponse(response_code=200, body=body), pagination
            else:
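                # Route reads through the reader pool when a separate reader
                # endpoint is configured; otherwise fall back to the writer pool.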
                db_pool = self.db.reader_pool if USE_SEPARATE_READER_POOL else self.db.pool
                with (await db_pool.cursor(
                        cursor_factory=psycopg2.extras.DictCursor
                )) as cur:
                    body, pagination = await _execute_on_cursor(cur)
                    cur.close()
                    return DBResponse(response_code=200, body=body), pagination
        except IndexError as error:
            return aiopg_exception_handling(error), None
        except (Exception, psycopg2.DatabaseError) as error:
            self.db.logger.exception("Exception occurred")
            return aiopg_exception_handling(error), None

    async def create_record(self, record_dict):
        # note: keep column names and values in the same order so the placeholders line up
        cols = []
        values = []
        for col_name, col_val in record_dict.items():
            cols.append(col_name)
            values.append(col_val)

        # add create ts
        cols.append("ts_epoch")
        values.append(get_db_ts_epoch_str())

        str_format = []
        for _ in cols:
            str_format.append("%s")

        separator = ", "

        insert_sql = """
                    INSERT INTO {0}({1}) VALUES({2})
                    RETURNING *
                    """.format(
            self.table_name, separator.join(cols), separator.join(str_format)
        )

        try:
            response_body = {}
            with (
                await self.db.pool.cursor(
                    cursor_factory=psycopg2.extras.DictCursor
                )
            ) as cur:

                await cur.execute(insert_sql, tuple(values))
                records = await cur.fetchall()
                record = records[0]
                filtered_record = {}
                for key, value in record.items():
                    if key in self.keys:
                        filtered_record[key] = value
                response_body = self._row_type(**filtered_record).serialize()  # pylint: disable=not-callable
                # todo make sure connection is closed even with error
                cur.close()
            return DBResponse(response_code=200, body=response_body)
        except (Exception, psycopg2.DatabaseError) as error:
            self.db.logger.exception("Exception occurred")
            return aiopg_exception_handling(error)

    async def run_in_transaction_with_serializable_isolation_level(self, fun):
        try:
            with (
                    await self.db.pool.cursor(
                        cursor_factory=psycopg2.extras.DictCursor,
                    )
            ) as cur:
                async with cur.begin():
                    await cur.execute('SET TRANSACTION ISOLATION LEVEL SERIALIZABLE')
                    res = await fun(cur)
                cur.close()  # is this really needed? TODO
                return res
        except psycopg2.errors.SerializationFailure:
            # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/409
            return DBResponse(response_code=409, body="Conflicting concurrent tag mutation, please retry")
        except (Exception, psycopg2.DatabaseError) as error:
            self.db.logger.exception("Exception occurred")
            return aiopg_exception_handling(error)

    async def update_row(self, filter_dict={}, update_dict={}, cur: aiopg.Cursor = None):
        # generate where clause
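        # Filter keys may carry an operator suffix (e.g. "last_heartbeat_ts:<="),
        # which expands to "(col IS NULL or col <op> value)"; plain keys use
        # equality. Values are interpolated directly into the SQL string, so
        # callers are expected to pass trusted values.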
        filters = []
        for col_name, col_val in filter_dict.items():
            operator = '='
            v = str(col_val).strip("'")
            if not v.isnumeric():
                v = "'" + v + "'"
            find_operator = operator_match.match(col_name)
            if find_operator:
                col_name = find_operator.group(1)
                operator = find_operator.group(2)
                filters.append('(%s IS NULL or %s %s %s)' %
                               (col_name, col_name, operator, str(v)))
            else:
                filters.append(col_name + operator + str(v))

        separator = " and "
        where_clause = ""
        if bool(filter_dict):
            where_clause = separator.join(filters)

        sets = []
        for col_name, col_val in update_dict.items():
            sets.append(col_name + " = " + str(col_val))

        set_separator = ", "
        set_clause = ""
        if bool(update_dict):
            set_clause = set_separator.join(sets)
        update_sql = """
                UPDATE {0} SET {1} WHERE {2};
        """.format(self.table_name, set_clause, where_clause)

        async def _execute_update_on_cursor(_cur):
            await _cur.execute(update_sql)
            if _cur.rowcount < 1:
                return DBResponse(response_code=404,
                                  body={"msg": "could not find row"})
            if _cur.rowcount > 1:
                return DBResponse(response_code=500,
                                  body={"msg": "duplicate rows"})
            return DBResponse(response_code=200, body={"rowcount": _cur.rowcount})
        if cur:
            return await _execute_update_on_cursor(cur)
        try:
            with (
                await self.db.pool.cursor(
                    cursor_factory=psycopg2.extras.DictCursor
                )
            ) as cur:
                db_response = await _execute_update_on_cursor(cur)
                cur.close()
                return db_response
        except (Exception, psycopg2.DatabaseError) as error:
            self.db.logger.exception("Exception occurred")
            return aiopg_exception_handling(error)


class PostgresUtils(object):
    @staticmethod
    async def create_trigger_if_missing(db: _AsyncPostgresDB, table_name, trigger_name, commands=[]):
        "executes the commands only if a trigger with the given name does not already exist on the table"
        with (await db.pool.cursor()) as cur:
            try:
                await cur.execute(
                    """
                    SELECT *
                    FROM information_schema.triggers
                    WHERE event_object_table = %s
                    AND trigger_name = %s
                    """,
                    (table_name, trigger_name),
                )
                trigger_exist = bool(cur.rowcount)
                if not trigger_exist:
                    for command in commands:
                        await cur.execute(command)
            finally:
                cur.close()

    @staticmethod
    async def cleanup_triggers(db: _AsyncPostgresDB, table_name, schema=DB_SCHEMA_NAME):
        "Cleans up old versions of table triggers"
        with (await db.pool.cursor()) as cur:
            try:
                await cur.execute(
                    """
                    SELECT DISTINCT trigger_name
                    FROM information_schema.triggers
                    WHERE event_object_table = %s
                    AND trigger_schema = %s
                    """,
                    [table_name, schema]
                )
                results = await cur.fetchall()

                triggers_to_cleanup = [
                    res[0] for res in results
                    if res[0].startswith(TRIGGER_NAME_PREFIX) and TRIGGER_VERSION not in res[0]
                ]
                if triggers_to_cleanup:
                    logging.getLogger("TriggerSetup").info("Cleaning up old triggers: %s" % triggers_to_cleanup)
                    commands = []
                    for trigger_name in triggers_to_cleanup:
                        commands += [
                            (f"DROP TRIGGER IF EXISTS {trigger_name} ON {schema}.{table_name}"),
                            (f"DROP FUNCTION IF EXISTS {schema}.{trigger_name}")
                        ]

                    for command in commands:
                        await cur.execute(command)
            finally:
                cur.close()

    @staticmethod
    async def setup_trigger_notify(
        db: _AsyncPostgresDB,
        table_name,
        keys: List[str] = None,
        schema=DB_SCHEMA_NAME,
        operations: List[str] = None,
        conditions: List[str] = None
    ):
        if not keys:
            return

        name_prefix = "%s_%s" % (TRIGGER_NAME_PREFIX, TRIGGER_VERSION)
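        # First command: a plpgsql trigger function that publishes the affected
        # row's key columns as a JSON payload on the 'notify' channel via pg_notify.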
        _commands = ["""
        CREATE OR REPLACE FUNCTION {schema}.{prefix}_{table}() RETURNS trigger
            LANGUAGE plpgsql
            AS $$
        DECLARE
            rec RECORD;
            BEGIN

            CASE TG_OP
            WHEN 'INSERT', 'UPDATE' THEN
                rec := NEW;
            WHEN 'DELETE' THEN
                rec := OLD;
            ELSE
                RAISE EXCEPTION 'Unknown TG_OP: "%"', TG_OP;
            END CASE;

            PERFORM pg_notify('notify', json_build_object(
                            'table',     TG_TABLE_NAME,
                            'schema',    TG_TABLE_SCHEMA,
                            'operation', TG_OP,
                            'data',      json_build_object({keys})
                    )::text);
            RETURN rec;
            END;
        $$;
        """.format(
            schema=schema,
            prefix=name_prefix,
            table=table_name,
            keys=", ".join(map(lambda k: "'{0}', rec.{0}".format(k), keys)),
        )]

        _commands += ["""
            CREATE TRIGGER {prefix}_{table} AFTER {events} ON {schema}.{table}
                FOR EACH ROW {conditions} EXECUTE PROCEDURE {schema}.{prefix}_{table}();
            """.format(
            schema=schema,
            prefix=name_prefix,
            table=table_name,
            events=" OR ".join(operations),
            conditions="WHEN (%s)" % " OR ".join(conditions) if conditions else ""
        )]

        # This enables trigger on both replica and non-replica mode
        _commands += ["ALTER TABLE {schema}.{table} ENABLE ALWAYS TRIGGER {prefix}_{table};".format(
            schema=schema,
            prefix=name_prefix,
            table=table_name
        )]

        # NOTE: Only try to setup triggers if they do not already exist.
        # This will require a table level lock so it should be performed during initial setup at off-peak hours.
        await PostgresUtils.create_trigger_if_missing(
            db=db,
            table_name=table_name,
            trigger_name="{}_{}".format(name_prefix, table_name),
            commands=_commands
        )


class AsyncFlowTablePostgres(AsyncPostgresTable):
    flow_dict = {}
    table_name = FLOW_TABLE_NAME
    keys = ["flow_id", "user_name", "ts_epoch", "tags", "system_tags"]
    primary_keys = ["flow_id"]
    trigger_keys = primary_keys
    select_columns = keys
    _row_type = FlowRow

    async def add_flow(self, flow: FlowRow):
        record_dict = {
            "flow_id": flow.flow_id,
            "user_name": flow.user_name,
            "tags": json.dumps(flow.tags),
            "system_tags": json.dumps(flow.system_tags),
        }
        return await self.create_record(record_dict)

    async def get_flow(self, flow_id: str):
        filter_dict = {"flow_id": flow_id}
        return await self.get_records(filter_dict=filter_dict, fetch_single=True)

    async def get_all_flows(self):
        return await self.get_records()


class AsyncRunTablePostgres(AsyncPostgresTable):
    run_dict = {}
    run_by_flow_dict = {}
    _current_count = 0
    _row_type = RunRow
    table_name = RUN_TABLE_NAME
    keys = ["flow_id", "run_number", "run_id",
            "user_name", "ts_epoch", "last_heartbeat_ts", "tags", "system_tags"]
    primary_keys = ["flow_id", "run_number"]
    trigger_keys = primary_keys + ["last_heartbeat_ts"]
    select_columns = keys
    flow_table_name = AsyncFlowTablePostgres.table_name

    async def add_run(self, run: RunRow, fill_heartbeat: bool = False):
        record_dict = {
            "flow_id": run.flow_id,
            "user_name": run.user_name,
            "tags": json.dumps(run.tags),
            "system_tags": json.dumps(run.system_tags),
            "run_id": run.run_id,
            "last_heartbeat_ts": str(new_heartbeat_ts()) if fill_heartbeat else None
        }
        return await self.create_record(record_dict)

    async def get_run(self, flow_id: str, run_id: str, expanded: bool = False, cur: aiopg.Cursor = None):
        key, value = translate_run_key(run_id)
        filter_dict = {"flow_id": flow_id, key: str(value)}
        return await self.get_records(filter_dict=filter_dict,
                                      fetch_single=True, expanded=expanded, cur=cur)

    async def get_all_runs(self, flow_id: str):
        filter_dict = {"flow_id": flow_id}
        return await self.get_records(filter_dict=filter_dict)

    async def update_heartbeat(self, flow_id: str, run_id: str):
        run_key, run_value = translate_run_key(run_id)
        new_hb = new_heartbeat_ts()
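        # Only refresh the heartbeat if the previous one is missing or at least
        # WAIT_TIME seconds old (the ":<=" filter also matches NULL values).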
        filter_dict = {"flow_id": flow_id,
                       run_key: str(run_value),
                       "last_heartbeat_ts:<=": new_hb - WAIT_TIME}
        set_dict = {
            "last_heartbeat_ts": new_hb
        }
        result = await self.update_row(filter_dict=filter_dict,
                                       update_dict=set_dict)
        body = {"wait_time_in_seconds": WAIT_TIME}

        return DBResponse(response_code=result.response_code,
                          body=json.dumps(body))

    async def update_run_tags(self, flow_id: str, run_id: str, run_tags: list, cur: aiopg.Cursor = None):
        run_key, run_value = translate_run_key(run_id)
        filter_dict = {"flow_id": flow_id,
                       run_key: str(run_value)}

        set_dict = {"tags": QuotedString(json.dumps(run_tags)).getquoted().decode()}
        return await self.update_row(filter_dict=filter_dict,
                                     update_dict=set_dict,
                                     cur=cur)


class AsyncStepTablePostgres(AsyncPostgresTable):
    step_dict = {}
    run_to_step_dict = {}
    _row_type = StepRow
    table_name = STEP_TABLE_NAME
    keys = ["flow_id", "run_number", "run_id", "step_name",
            "user_name", "ts_epoch", "tags", "system_tags"]
    primary_keys = ["flow_id", "run_number", "step_name"]
    trigger_keys = primary_keys
    select_columns = keys
    run_table_name = AsyncRunTablePostgres.table_name

    async def add_step(self, step_object: StepRow):
        record_dict = {
            "flow_id": step_object.flow_id,
            "run_number": str(step_object.run_number),
            "run_id": step_object.run_id,
            "step_name": step_object.step_name,
            "user_name": step_object.user_name,
            "tags": json.dumps(step_object.tags),
            "system_tags": json.dumps(step_object.system_tags),
        }
        return await self.create_record(record_dict)

    async def get_steps(self, flow_id: str, run_id: str):
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {"flow_id": flow_id,
                       run_id_key: run_id_value}
        return await self.get_records(filter_dict=filter_dict)

    async def get_step(self, flow_id: str, run_id: str, step_name: str):
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
        }
        return await self.get_records(filter_dict=filter_dict, fetch_single=True)


class AsyncTaskTablePostgres(AsyncPostgresTable):
    task_dict = {}
    step_to_task_dict = {}
    _current_count = 0
    _row_type = TaskRow
    table_name = TASK_TABLE_NAME
    keys = ["flow_id", "run_number", "run_id", "step_name", "task_id",
            "task_name", "user_name", "ts_epoch", "last_heartbeat_ts", "tags", "system_tags"]
    primary_keys = ["flow_id", "run_number", "step_name", "task_id"]
    trigger_keys = primary_keys
    select_columns = keys
    step_table_name = AsyncStepTablePostgres.table_name

    async def add_task(self, task: TaskRow, fill_heartbeat=False):
        # todo backfill run_number if missing?
        record_dict = {
            "flow_id": task.flow_id,
            "run_number": str(task.run_number),
            "run_id": task.run_id,
            "step_name": task.step_name,
            "task_name": task.task_name,
            "user_name": task.user_name,
            "tags": json.dumps(task.tags),
            "system_tags": json.dumps(task.system_tags),
            "last_heartbeat_ts": str(new_heartbeat_ts()) if fill_heartbeat else None
        }
        return await self.create_record(record_dict)

    async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
        }
        return await self.get_records(filter_dict=filter_dict)

    async def get_task(self, flow_id: str, run_id: str, step_name: str,
                       task_id: str, expanded: bool = False):
        run_id_key, run_id_value = translate_run_key(run_id)
        task_id_key, task_id_value = translate_task_key(task_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
            task_id_key: task_id_value,
        }
        return await self.get_records(filter_dict=filter_dict,
                                      fetch_single=True, expanded=expanded)

    async def update_heartbeat(self, flow_id: str, run_id: str, step_name: str,
                               task_id: str):
        run_key, run_value = translate_run_key(run_id)
        task_key, task_value = translate_task_key(task_id)
        new_hb = new_heartbeat_ts()
        filter_dict = {"flow_id": flow_id,
                       run_key: str(run_value),
                       "step_name": step_name,
                       task_key: str(task_value),
                       "last_heartbeat_ts:<=": new_hb - WAIT_TIME}
        set_dict = {
            "last_heartbeat_ts": new_hb
        }
        result = await self.update_row(filter_dict=filter_dict,
                                       update_dict=set_dict)

        body = {"wait_time_in_seconds": WAIT_TIME}

        return DBResponse(response_code=result.response_code,
                          body=json.dumps(body))


class AsyncMetadataTablePostgres(AsyncPostgresTable):
    metadata_dict = {}
    run_to_metadata_dict = {}
    _current_count = 0
    _row_type = MetadataRow
    table_name = METADATA_TABLE_NAME
    keys = ["flow_id", "run_number", "run_id", "step_name", "task_id", "task_name", "id",
            "field_name", "value", "type", "user_name", "ts_epoch", "tags", "system_tags"]
    primary_keys = ["flow_id", "run_number",
                    "step_name", "task_id", "field_name"]
    trigger_keys = ["flow_id", "run_number",
                    "step_name", "task_id", "field_name", "value", "tags"]
    trigger_operations = ["INSERT"]
    select_columns = keys

    async def add_metadata(
        self,
        flow_id,
        run_number,
        run_id,
        step_name,
        task_id,
        task_name,
        field_name,
        value,
        type,
        user_name,
        tags,
        system_tags,
    ):
        record_dict = {
            "flow_id": flow_id,
            "run_number": str(run_number),
            "run_id": run_id,
            "step_name": step_name,
            "task_id": str(task_id),
            "task_name": task_name,
            "field_name": field_name,
            "value": value,
            "type": type,
            "user_name": user_name,
            "tags": json.dumps(tags),
            "system_tags": json.dumps(system_tags),
        }
        return await self.create_record(record_dict)

    async def get_metadata_in_runs(self, flow_id: str, run_id: str):
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {"flow_id": flow_id,
                       run_id_key: run_id_value}
        return await self.get_records(filter_dict=filter_dict)

    async def get_metadata(
        self, flow_id: str, run_id: int, step_name: str, task_id: str
    ):
        run_id_key, run_id_value = translate_run_key(run_id)
        task_id_key, task_id_value = translate_task_key(task_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
            task_id_key: task_id_value,
        }
        return await self.get_records(filter_dict=filter_dict)

    async def get_filtered_task_pathspecs(self, flow_id: str, run_id: str, step_name: str, field_name: str, pattern: str):
        """
        Returns a list of task pathspecs that match the given field_name and regexp pattern for the value
        """
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
        }
        conditions = [f"{k} = %s" for k, v in filter_dict.items() if v is not None]
        values = [v for k, v in filter_dict.items() if v is not None]

        if field_name:
            conditions.append("field_name = %s")
            values.append(field_name)

        if pattern:
            conditions.append("regexp_match(value, %s) IS NOT NULL")
            values.append(pattern)

        # We must return distinct task pathspecs, so we construct the select statement by hand
        sql_template = """
        SELECT DISTINCT {select_columns} FROM (
            SELECT
                {keys}
            FROM {table_name}
        ) T
        {where}
        {order_by}
        """

        select_sql = sql_template.format(
            keys=",".join(self.select_columns),
            table_name=self.table_name,
            where="WHERE {}".format(" AND ".join(conditions)),
            order_by="ORDER BY task_id",
            select_columns="flow_id, run_number, run_id, step_name, task_name, task_id"
        ).strip()

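        # serialize=False returns raw (unserialized) rows, which are unpacked
        # positionally into pathspecs below.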
        db_response, pagination = await self.execute_sql(select_sql=select_sql, values=values, serialize=False)

        # flatten the ids in the response
        def _format_id(row):
            flow_id, run_number, run_id, step_name, task_name, task_id = row
            # pathspec
            return f"{flow_id}/{run_id or run_number}/{step_name}/{task_name or task_id}"

        flattened_response = DBResponse(body=[_format_id(row) for row in db_response.body], response_code=db_response.response_code)
        return flattened_response, pagination


class AsyncArtifactTablePostgres(AsyncPostgresTable):
    artifact_dict = {}
    run_to_artifact_dict = {}
    step_to_artifact_dict = {}
    task_to_artifact_dict = {}
    current_count = 0
    _row_type = ArtifactRow
    table_name = ARTIFACT_TABLE_NAME
    ordering = ["attempt_id DESC"]
    keys = ["flow_id", "run_number", "run_id", "step_name", "task_id", "task_name", "name", "location",
            "ds_type", "sha", "type", "content_type", "user_name", "attempt_id", "ts_epoch", "tags", "system_tags"]
    primary_keys = ["flow_id", "run_number",
                    "step_name", "task_id", "attempt_id", "name"]
    trigger_keys = primary_keys
    trigger_operations = ["INSERT"]
    select_columns = keys

    async def add_artifact(
        self,
        flow_id,
        run_number,
        run_id,
        step_name,
        task_id,
        task_name,
        name,
        location,
        ds_type,
        sha,
        type,
        content_type,
        user_name,
        attempt_id,
        tags,
        system_tags,
    ):
        record_dict = {
            "flow_id": flow_id,
            "run_number": str(run_number),
            "run_id": run_id,
            "step_name": step_name,
            "task_id": str(task_id),
            "task_name": task_name,
            "name": name,
            "location": location,
            "ds_type": ds_type,
            "sha": sha,
            "type": type,
            "content_type": content_type,
            "user_name": user_name,
            "attempt_id": str(attempt_id),
            "tags": json.dumps(tags),
            "system_tags": json.dumps(system_tags),
        }
        return await self.create_record(record_dict)

    async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
        }
        return await self.get_records(filter_dict=filter_dict,
                                      ordering=self.ordering)

    async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str):
        run_id_key, run_id_value = translate_run_key(run_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
        }
        return await self.get_records(filter_dict=filter_dict,
                                      ordering=self.ordering)

    async def get_artifact_in_task(
        self, flow_id: str, run_id: int, step_name: str, task_id: int
    ):
        run_id_key, run_id_value = translate_run_key(run_id)
        task_id_key, task_id_value = translate_task_key(task_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
            task_id_key: task_id_value,
        }
        return await self.get_records(filter_dict=filter_dict,
                                      ordering=self.ordering)

    async def get_artifact(
        self, flow_id: str, run_id: int, step_name: str, task_id: int, name: str
    ):
        # Return the artifact metadata for the latest attempt of the task.
        #
        # The quirk here is that different attempts may have different sets of
        # artifacts. That is, if artifact "foo" was set in attempt N, that
        # doesn't mean it was set in attempt N+1, and vice versa.
        #
        # To get the artifact value for the "latest" attempt, we first find
        # the latest attempt_id by querying the artifacts table for the
        # artifact that always exists for every attempt (the artifact called
        # 'name', containing the flow name), then use that attempt_id to get
        # the artifact we're interested in.
        run_id_key, run_id_value = translate_run_key(run_id)
        task_id_key, task_id_value = translate_task_key(task_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
            task_id_key: task_id_value,
            '"name"': "name"
        }
        name_record = await self.get_records(filter_dict=filter_dict,
                                             fetch_single=True, ordering=self.ordering)

        return await self.get_artifact_by_attempt(
            flow_id, run_id, step_name, task_id, name, name_record.body.get('attempt_id', 0))

    async def get_artifact_by_attempt(
            self, flow_id: str, run_id: int, step_name: str, task_id: int, name: str,
            attempt: int):

        run_id_key, run_id_value = translate_run_key(run_id)
        task_id_key, task_id_value = translate_task_key(task_id)
        filter_dict = {
            "flow_id": flow_id,
            run_id_key: run_id_value,
            "step_name": step_name,
            task_id_key: task_id_value,
            '"name"': name,
            '"attempt_id"': attempt
        }
        return await self.get_records(filter_dict=filter_dict,
                                      fetch_single=True, ordering=self.ordering)
