superset/daos/key_value.py (113 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import and_

from superset import db
from superset.daos.base import BaseDAO
from superset.key_value.exceptions import (
    KeyValueCreateFailedError,
    KeyValueUpdateFailedError,
)
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)


class KeyValueDAO(BaseDAO[KeyValueEntry]):
    """Data-access helpers for ``KeyValueEntry`` rows.

    None of the methods below call ``commit``; they only stage changes on
    (or query through) ``db.session``.
    """

    @staticmethod
    def get_entry(
        resource: KeyValueResource,
        key: Key,
    ) -> KeyValueEntry | None:
        """Return the first entry matching ``resource`` and ``key``.

        The lookup criteria come from ``get_filter(resource, key)``.

        :param resource: resource the entry belongs to
        :param key: identifier of the entry
        :returns: the matching entry, or ``None`` when the query finds nothing
        """
        criteria = get_filter(resource, key)
        query = db.session.query(KeyValueEntry)
        return query.filter_by(**criteria).first()

    @classmethod
    def get_value(
        cls,
        resource: KeyValueResource,
        key: Key,
        codec: KeyValueCodec,
    ) -> Any:
        """Return the decoded value stored under ``key``.

        :param resource: resource the entry belongs to
        :param key: identifier of the entry
        :param codec: codec used to decode the stored value
        :returns: the decoded value, or ``None`` when there is no entry or
            the entry's ``is_expired()`` returns a truthy result
        """
        entry = cls.get_entry(resource, key)
        if entry and not entry.is_expired():
            return codec.decode(entry.value)
        return None

    @staticmethod
    def delete_entry(resource: KeyValueResource, key: Key) -> bool:
        """Mark the entry stored under ``key`` for deletion in the session.

        :param resource: resource the entry belongs to
        :param key: identifier of the entry
        :returns: ``True`` if an entry was found and passed to
            ``db.session.delete``, ``False`` otherwise
        """
        entry = KeyValueDAO.get_entry(resource, key)
        if not entry:
            return False
        db.session.delete(entry)
        return True

    @staticmethod
    def delete_expired_entries(resource: KeyValueResource) -> None:
        """Bulk-delete the entries of ``resource`` whose expiry has passed.

        An entry is selected when its ``expires_on`` is less than or equal to
        ``datetime.now()`` at the time of the call.

        :param resource: resource whose expired entries are removed
        """
        entries = db.session.query(KeyValueEntry)
        is_expired_for_resource = and_(
            KeyValueEntry.resource == resource.value,
            KeyValueEntry.expires_on <= datetime.now(),
        )
        entries.filter(is_expired_for_resource).delete()

    @staticmethod
    def create_entry(
        resource: KeyValueResource,
        value: Any,
        codec: KeyValueCodec,
        key: Key | None = None,
        expires_on: datetime | None = None,
    ) -> KeyValueEntry:
        """Build a new entry and add it to the session.

        :param resource: resource the entry belongs to
        :param value: value to store; encoded with ``codec``
        :param codec: codec used to encode ``value``
        :param key: optional explicit key; a ``UUID`` is stored in ``uuid``,
            any other non-``None`` key in ``id``
        :param expires_on: optional expiry timestamp
        :returns: the newly created entry
        :raises KeyValueCreateFailedError: if encoding fails, or if assigning
            the key raises ``ValueError``
        """
        try:
            payload = codec.encode(value)
        except Exception as ex:
            raise KeyValueCreateFailedError("Unable to encode value") from ex

        entry = KeyValueEntry(
            resource=resource.value,
            value=payload,
            created_on=datetime.now(),
            created_by_fk=get_user_id(),
            expires_on=expires_on,
        )

        try:
            if isinstance(key, UUID):
                entry.uuid = key
            elif key is not None:
                entry.id = key
        except ValueError as ex:
            raise KeyValueCreateFailedError() from ex

        db.session.add(entry)
        return entry

    @staticmethod
    def _overwrite(
        entry: KeyValueEntry,
        value: Any,
        codec: KeyValueCodec,
        expires_on: datetime | None,
    ) -> KeyValueEntry:
        """Replace the value and expiry of ``entry`` and stamp the change."""
        entry.value = codec.encode(value)
        entry.expires_on = expires_on
        entry.changed_on = datetime.now()
        entry.changed_by_fk = get_user_id()
        return entry

    @staticmethod
    def upsert_entry(
        resource: KeyValueResource,
        value: Any,
        codec: KeyValueCodec,
        key: Key,
        expires_on: datetime | None = None,
    ) -> KeyValueEntry:
        """Update the entry stored under ``key``, creating it if absent.

        :param resource: resource the entry belongs to
        :param value: value to store; encoded with ``codec``
        :param codec: codec used to encode ``value``
        :param key: identifier of the entry
        :param expires_on: optional expiry timestamp
        :returns: the updated or newly created entry
        :raises KeyValueCreateFailedError: propagated from ``create_entry``
            when no entry exists and creation fails
        """
        existing = KeyValueDAO.get_entry(resource, key)
        if not existing:
            return KeyValueDAO.create_entry(resource, value, codec, key, expires_on)
        return KeyValueDAO._overwrite(existing, value, codec, expires_on)

    @staticmethod
    def update_entry(
        resource: KeyValueResource,
        value: Any,
        codec: KeyValueCodec,
        key: Key,
        expires_on: datetime | None = None,
    ) -> KeyValueEntry:
        """Update the entry stored under ``key``.

        :param resource: resource the entry belongs to
        :param value: value to store; encoded with ``codec``
        :param codec: codec used to encode ``value``
        :param key: identifier of the entry
        :param expires_on: optional expiry timestamp
        :returns: the updated entry
        :raises KeyValueUpdateFailedError: if no entry is found for ``key``
        """
        existing = KeyValueDAO.get_entry(resource, key)
        if not existing:
            raise KeyValueUpdateFailedError()
        return KeyValueDAO._overwrite(existing, value, codec, expires_on)