pyignite/transaction.py (81 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import IntEnum
from typing import Union, Type

from pyignite.api.tx_api import tx_end, tx_start, tx_end_async, tx_start_async
from pyignite.datatypes import TransactionIsolation, TransactionConcurrency
from pyignite.exceptions import CacheError
from pyignite.utils import status_to_exception


def _validate_int_enum_param(value: Union[int, IntEnum], cls: Type[IntEnum]):
    """
    Check that `value` equals the value of one of the members of `cls`.

    :param value: raw integer or enum member to validate,
    :param cls: `IntEnum` subclass that defines the allowed values,
    :return: `value`, unchanged,
    :raises ValueError: if `value` matches none of the members of `cls`.
    """
    # Use this trick to disable warning on python 3.7: membership is tested
    # against the plain member values rather than against the enum class
    # itself.
    if value not in {member.value for member in cls}:
        raise ValueError(f'{value} not in {cls}')
    return value


def _validate_timeout(value):
    """
    Check that `value` is an integer greater than or equal to zero.

    Zero is accepted (it is the default timeout of the transaction classes
    in this module), even though the error message says "positive".

    :param value: timeout to validate,
    :return: `value`, unchanged,
    :raises ValueError: if `value` is not an `int` or is negative.
    """
    if not isinstance(value, int) or value < 0:
        raise ValueError(f'Timeout value should be a positive integer, {value} passed instead')
    return value


def _validate_label(value):
    """
    Check that `value` is either `None` (no label) or a string.

    :param value: label to validate,
    :return: `value`, unchanged,
    :raises ValueError: if `value` is neither `None` nor a `str`.
    """
    # Compare against None explicitly: a plain truthiness check would let
    # falsy non-string values (0, b'', [], ...) slip through unvalidated.
    if value is not None and not isinstance(value, str):
        raise ValueError(f'Label should be str, {type(value)} passed instead')
    return value


class _BaseTransaction:
    """
    State shared by the sync and async transactions: the owning client,
    the validated transaction parameters and the `closed` flag.
    """
    def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC,
                 isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None):
        """
        :param client: client object the transaction belongs to,
        :param concurrency: a value of :class:`TransactionConcurrency`,
        :param isolation: a value of :class:`TransactionIsolation`,
        :param timeout: non-negative integer timeout
         (NOTE(review): unit is not visible in this module, confirm),
        :param label: (optional) transaction label, `str` or `None`,
        :raises ValueError: if any of the parameters fails validation.
        """
        self.client = client
        self.concurrency = _validate_int_enum_param(concurrency, TransactionConcurrency)
        self.isolation = _validate_int_enum_param(isolation, TransactionIsolation)
        self.timeout = _validate_timeout(timeout)
        self.label, self.closed = _validate_label(label), False


class Transaction(_BaseTransaction):
    """
    Thin client transaction.

    The transaction is started in the constructor. It can be used as
    a context manager: leaving the `with` block calls :meth:`close`, which
    ends the transaction without committing unless :meth:`commit` was
    called before.
    """
    def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC,
                 isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None):
        """
        Validate the parameters and start the transaction immediately.

        See :class:`_BaseTransaction` for the meaning of the parameters.
        """
        super().__init__(client, concurrency, isolation, timeout, label)
        self.tx_id = self.__start_tx()

    def commit(self) -> None:
        """
        Commit transaction.

        Does nothing if the transaction has already been closed.
        """
        if not self.closed:
            # Mark as closed first, so that the transaction is never ended
            # twice, even if the request below fails.
            self.closed = True
            return self.__end_tx(True)

    def rollback(self) -> None:
        """
        Rollback transaction.
        """
        self.close()

    def close(self) -> None:
        """
        Close transaction.

        Ends the transaction without committing. Does nothing if the
        transaction has already been closed.
        """
        if not self.closed:
            self.closed = True
            return self.__end_tx(False)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @status_to_exception(CacheError)
    def __start_tx(self):
        # The transaction is started on a node picked by the client.
        conn = self.client.random_node
        return tx_start(conn, self.concurrency, self.isolation, self.timeout, self.label)

    @status_to_exception(CacheError)
    def __end_tx(self, committed):
        return tx_end(self.tx_id, committed)


class AioTransaction(_BaseTransaction):
    """
    Async thin client transaction.

    Unlike :class:`Transaction`, the transaction is not started in the
    constructor: it is started when the object is awaited or entered with
    `async with`.
    """
    def __init__(self, client, concurrency=TransactionConcurrency.PESSIMISTIC,
                 isolation=TransactionIsolation.REPEATABLE_READ, timeout=0, label=None):
        """
        Validate the parameters. The transaction itself is not started here.

        See :class:`_BaseTransaction` for the meaning of the parameters.
        """
        super().__init__(client, concurrency, isolation, timeout, label)
        # No transaction exists until the object is awaited or entered.
        # Keeping an explicit None lets commit()/close() recognise this state
        # instead of failing with AttributeError on a missing attribute.
        self.tx_id = None

    def __await__(self):
        # `await tx` starts the transaction and evaluates to the transaction.
        return (yield from self.__aenter__().__await__())

    async def commit(self) -> None:
        """
        Commit transaction.

        Does nothing if the transaction has already been closed or has never
        been started.
        """
        if self.closed or self.tx_id is None:
            return None
        # Mark as closed first, so that the transaction is never ended
        # twice, even if the request below fails.
        self.closed = True
        return await self.__end_tx(True)

    async def rollback(self) -> None:
        """
        Rollback transaction.
        """
        await self.close()

    async def close(self) -> None:
        """
        Close transaction.

        Ends the transaction without committing. Does nothing if the
        transaction has already been closed or has never been started.
        """
        if self.closed or self.tx_id is None:
            return None
        self.closed = True
        return await self.__end_tx(False)

    async def __aenter__(self):
        self.tx_id = await self.__start_tx()
        self.closed = False
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    @status_to_exception(CacheError)
    async def __start_tx(self):
        # The transaction is started on a node picked by the client.
        conn = await self.client.random_node()
        return await tx_start_async(conn, self.concurrency, self.isolation, self.timeout, self.label)

    @status_to_exception(CacheError)
    async def __end_tx(self, committed):
        return await tx_end_async(self.tx_id, committed)