pyignite/cursors.py (225 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains sync and async cursors for different types of queries.
"""

import asyncio

from pyignite.api import (
    scan, scan_cursor_get_page, resource_close, scan_async, scan_cursor_get_page_async, resource_close_async, sql,
    sql_cursor_get_page, sql_fields, sql_fields_cursor_get_page, sql_fields_cursor_get_page_async, sql_fields_async
)
from pyignite.exceptions import CacheError, SQLError

__all__ = ['ScanCursor', 'SqlCursor', 'SqlFieldsCursor', 'AioScanCursor', 'AioSqlFieldsCursor']


class BaseCursorMixin:
    """
    Attribute storage shared by every cursor in this module.

    Each property reads its backing attribute with a ``None`` default, so a
    cursor can be inspected (and safely closed) before it has been
    initialized.
    """
    @property
    def connection(self):
        """
        Ignite cluster connection.
        """
        return getattr(self, '_conn', None)

    @connection.setter
    def connection(self, value):
        setattr(self, '_conn', value)

    @property
    def cursor_id(self):
        """
        Cursor id.
        """
        return getattr(self, '_cursor_id', None)

    @cursor_id.setter
    def cursor_id(self, value):
        setattr(self, '_cursor_id', value)

    @property
    def more(self):
        """
        Whether cursor has more values.
        """
        return getattr(self, '_more', None)

    @more.setter
    def more(self, value):
        setattr(self, '_more', value)

    @property
    def cache_info(self):
        """
        Cache meta info.
        """
        return getattr(self, '_cache_info', None)

    @cache_info.setter
    def cache_info(self, value):
        setattr(self, '_cache_info', value)

    @property
    def client(self):
        """
        Apache Ignite client.
        """
        return getattr(self, '_client', None)

    @client.setter
    def client(self, value):
        setattr(self, '_client', value)

    @property
    def data(self):
        """
        Current fetched data.
        """
        return getattr(self, '_data', None)

    @data.setter
    def data(self, value):
        setattr(self, '_data', value)


class CursorMixin(BaseCursorMixin):
    """
    Iterator and context manager protocol for synchronous cursors.

    Subclasses are expected to provide ``__next__``.
    """
    def __enter__(self):
        return self

    def __iter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """
        Close cursor.

        The close request is sent only when the cursor has a connection,
        a cursor id and still reports more pages; otherwise this is a no-op.
        """
        if self.connection and self.cursor_id and self.more:
            resource_close(self.connection, self.cursor_id)


class AioCursorMixin(BaseCursorMixin):
    """
    Async iterator and async context manager protocol for asynchronous
    cursors.

    Subclasses are expected to provide ``__aenter__`` and ``__anext__``.
    """
    def __await__(self):
        # Awaiting the cursor is equivalent to entering it: the subclass'
        # ``__aenter__`` is run and its result is returned.
        return (yield from self.__aenter__().__await__())

    def __aiter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self):
        """
        Close cursor.

        The close request is sent only when the cursor has a connection,
        a cursor id and still reports more pages; otherwise this is a no-op.
        """
        if self.connection and self.cursor_id and self.more:
            await resource_close_async(self.connection, self.cursor_id)


class AbstractScanCursor:
    """
    State and response handling shared by sync and async scan cursors.
    """
    def __init__(self, client, cache_info, page_size, partitions, local):
        self.client = client
        self.cache_info = cache_info
        self._page_size = page_size
        self._partitions = partitions
        self._local = local

    def _finalize_init(self, result):
        """
        Store cursor id, `more` flag and the first page from the scan
        response.

        :raises CacheError: if the response status is not 0.
        """
        if result.status != 0:
            raise CacheError(result.message)

        self.cursor_id, self.more = result.value['cursor'], result.value['more']
        self.data = iter(result.value['data'].items())

    def _process_page_response(self, result):
        """
        Replace current page and `more` flag with the ones from a
        "get page" response.

        :raises CacheError: if the response status is not 0.
        """
        if result.status != 0:
            raise CacheError(result.message)

        self.data, self.more = iter(result.value['data'].items()), result.value['more']


class ScanCursor(AbstractScanCursor, CursorMixin):
    """
    Synchronous scan cursor.
    """
    def __init__(self, client, cache_info, page_size, partitions, local):
        """
        :param client: Synchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        :param page_size: page size.
        :param partitions: number of partitions to query (negative to query entire cache).
        :param local: pass True if this query should be executed on local node only.
        """
        super().__init__(client, cache_info, page_size, partitions, local)

        self.connection = self.client.random_node
        result = scan(self.connection, self.cache_info, self._page_size, self._partitions, self._local)
        self._finalize_init(result)

    def __next__(self):
        """
        Return the next (key, value) pair, fetching the next page when the
        current one is exhausted and the cursor reports more data.
        """
        if not self.data:
            raise StopIteration

        try:
            k, v = next(self.data)
        except StopIteration:
            if self.more:
                self._process_page_response(scan_cursor_get_page(self.connection, self.cursor_id))
                # An empty page lets StopIteration propagate and ends the iteration.
                k, v = next(self.data)
            else:
                raise StopIteration

        return self.client.unwrap_binary(k), self.client.unwrap_binary(v)


class AioScanCursor(AbstractScanCursor, AioCursorMixin):
    """
    Asynchronous scan query cursor.
    """
    def __init__(self, client, cache_info, page_size, partitions, local):
        """
        :param client: Asynchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        :param page_size: page size.
        :param partitions: number of partitions to query (negative to query entire cache).
        :param local: pass True if this query should be executed on local node only.
        """
        super().__init__(client, cache_info, page_size, partitions, local)

    async def __aenter__(self):
        # The query is issued only once: a cursor that already has
        # a connection is returned as is.
        if not self.connection:
            self.connection = await self.client.random_node()
            result = await scan_async(self.connection, self.cache_info, self._page_size, self._partitions,
                                      self._local)
            self._finalize_init(result)
        return self

    async def __anext__(self):
        """
        Return the next key and value, fetching the next page when the
        current one is exhausted and the cursor reports more data.

        :raises CacheError: if the cursor has not been initialized.
        """
        if not self.connection:
            raise CacheError("Using uninitialized cursor, initialize it using async with expression.")

        if not self.data:
            raise StopAsyncIteration

        try:
            k, v = next(self.data)
        except StopIteration:
            if self.more:
                self._process_page_response(await scan_cursor_get_page_async(self.connection, self.cursor_id))
                try:
                    k, v = next(self.data)
                except StopIteration:
                    # A plain StopIteration must not leak out of a coroutine.
                    raise StopAsyncIteration
            else:
                raise StopAsyncIteration

        return await asyncio.gather(
            *[self.client.unwrap_binary(k), self.client.unwrap_binary(v)]
        )


class SqlCursor(CursorMixin):
    """
    Synchronous SQL query cursor.
    """
    def __init__(self, client, cache_info, *args, **kwargs):
        """
        :param client: Synchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        """
        self.client = client
        self.cache_info = cache_info
        self.connection = self.client.random_node
        result = sql(self.connection, self.cache_info, *args, **kwargs)
        if result.status != 0:
            raise SQLError(result.message)

        self.cursor_id, self.more = result.value['cursor'], result.value['more']
        self.data = iter(result.value['data'].items())

    def __next__(self):
        """
        Return the next (key, value) pair, fetching the next page when the
        current one is exhausted and the cursor reports more data.

        :raises SQLError: if fetching the next page fails.
        """
        if not self.data:
            raise StopIteration

        try:
            k, v = next(self.data)
        except StopIteration:
            if self.more:
                result = sql_cursor_get_page(self.connection, self.cursor_id)
                if result.status != 0:
                    raise SQLError(result.message)
                self.data, self.more = iter(result.value['data'].items()), result.value['more']

                # An empty page lets StopIteration propagate and ends the iteration.
                k, v = next(self.data)
            else:
                raise StopIteration

        return self.client.unwrap_binary(k), self.client.unwrap_binary(v)


class AbstractSqlFieldsCursor:
    """
    State and response handling shared by sync and async SQL fields cursors.
    """
    def __init__(self, client, cache_info):
        self.client = client
        self.cache_info = cache_info

    def _finalize_init(self, result):
        """
        Store cursor id, `more` flag, the first page and field meta from
        the query response.

        :raises SQLError: if the response status is not 0.
        """
        if result.status != 0:
            raise SQLError(result.message)

        self.cursor_id, self.more = result.value['cursor'], result.value['more']
        self.data = iter(result.value['data'])
        self._field_names = result.value.get('fields', None)
        # The field count comes from the names when they are present,
        # otherwise from the explicit 'field_count' entry of the response.
        if self._field_names:
            self._field_count = len(self._field_names)
        else:
            self._field_count = result.value['field_count']


class SqlFieldsCursor(AbstractSqlFieldsCursor, CursorMixin):
    """
    Synchronous SQL fields query cursor.
    """
    def __init__(self, client, cache_info, *args, **kwargs):
        """
        :param client: Synchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        """
        super().__init__(client, cache_info)
        self.connection = self.client.random_node
        self._finalize_init(sql_fields(self.connection, self.cache_info, *args, **kwargs))

    def __next__(self):
        """
        Return the next row as a list of values.

        When the response contained field names, they are returned once,
        before any data row.

        :raises SQLError: if fetching the next page fails.
        """
        if not self.data:
            raise StopIteration

        if self._field_names:
            result = self._field_names
            self._field_names = None
            return result

        try:
            row = next(self.data)
        except StopIteration:
            if self.more:
                result = sql_fields_cursor_get_page(self.connection, self.cursor_id, self._field_count)
                if result.status != 0:
                    raise SQLError(result.message)

                self.data, self.more = iter(result.value['data']), result.value['more']

                # An empty page lets StopIteration propagate and ends the iteration.
                row = next(self.data)
            else:
                raise StopIteration

        return [self.client.unwrap_binary(v) for v in row]


class AioSqlFieldsCursor(AbstractSqlFieldsCursor, AioCursorMixin):
    """
    Asynchronous SQL fields query cursor.
    """
    def __init__(self, client, cache_info, *args, **kwargs):
        """
        :param client: Asynchronous Apache Ignite client.
        :param cache_info: Cache meta info.
        """
        super().__init__(client, cache_info)
        # The query is deferred until the cursor is awaited or entered,
        # so its arguments are kept for later.
        self._params = (args, kwargs)

    async def __aenter__(self):
        args, kwargs = self._params
        # Keyword arguments must be unpacked with `**`: unpacking the dict
        # with a single `*` would pass its keys as positional arguments.
        await self._initialize(*args, **kwargs)
        return self

    async def __anext__(self):
        """
        Return the next row as a list of values.

        When the response contained field names, they are returned once,
        before any data row.

        :raises SQLError: if the cursor has not been initialized or
         fetching the next page fails.
        """
        if not self.connection:
            raise SQLError("Attempting to use uninitialized aio cursor, please await on it or use with expression.")

        if not self.data:
            raise StopAsyncIteration

        if self._field_names:
            result = self._field_names
            self._field_names = None
            return result

        try:
            row = next(self.data)
        except StopIteration:
            if self.more:
                result = await sql_fields_cursor_get_page_async(self.connection, self.cursor_id, self._field_count)
                if result.status != 0:
                    raise SQLError(result.message)

                self.data, self.more = iter(result.value['data']), result.value['more']
                try:
                    row = next(self.data)
                except StopIteration:
                    # A plain StopIteration must not leak out of a coroutine.
                    raise StopAsyncIteration
            else:
                raise StopAsyncIteration

        return await asyncio.gather(*[self.client.unwrap_binary(v) for v in row])

    async def _initialize(self, *args, **kwargs):
        """
        Pick a connection and run the query, unless it has been done already.
        """
        if self.connection and self.cursor_id:
            return

        self.connection = await self.client.random_node()
        self._finalize_init(await sql_fields_async(self.connection, self.cache_info, *args, **kwargs))