src/google/appengine/ext/ndb/context.py (832 lines of code) (raw):

#!/usr/bin/env python
#
# Copyright 2007 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Context class.

A `Context` routes NDB get/put/delete calls through an in-memory cache,
memcache and the datastore, combining individual calls into multi-item RPCs
with `AutoBatcher`. It also provides auto-batched memcache wrappers,
transaction support and an asynchronous urlfetch helper.
"""

import logging
import sys

from google.appengine.ext.ndb import eventloop
from google.appengine.ext.ndb import key as key_module
from google.appengine.ext.ndb import model
from google.appengine.ext.ndb import tasklets
from google.appengine.ext.ndb import utils
import six
from six.moves import range
from six.moves import zip

from google.appengine.api import datastore
from google.appengine.api import datastore_errors
from google.appengine.api import memcache
from google.appengine.api import namespace_manager
from google.appengine.api import urlfetch
from google.appengine.datastore import datastore_rpc
from google.protobuf import message
from google.appengine.datastore import entity_bytes_pb2 as entity_pb2


__all__ = ['Context', 'ContextOptions', 'TransactionOptions', 'AutoBatcher',
           'EVENTUAL_CONSISTENCY',
          ]


# Expiration (`time` argument to memcache_set; presumably seconds) of the
# lock placeholder written to memcache before a datastore write/read.
_LOCK_TIME = 32

# Placeholder value stored in memcache while a key is "locked".
_LOCKED = 0


EVENTUAL_CONSISTENCY = datastore_rpc.Configuration.EVENTUAL_CONSISTENCY


class ContextOptions(datastore_rpc.Configuration):
  """Configuration options that may be passed along with get/put/delete.

  Each option below validates its raw value: it either returns the value
  unchanged or raises `datastore_errors.BadArgumentError`.
  """

  @datastore_rpc.ConfigOption
  def use_cache(value):
    """Whether to use the in-memory context cache (bool)."""
    if not isinstance(value, bool):
      raise datastore_errors.BadArgumentError(
          'use_cache should be a bool (%r)' % (value,))
    return value

  @datastore_rpc.ConfigOption
  def use_memcache(value):
    """Whether to use memcache (bool)."""
    if not isinstance(value, bool):
      raise datastore_errors.BadArgumentError(
          'use_memcache should be a bool (%r)' % (value,))
    return value

  @datastore_rpc.ConfigOption
  def use_datastore(value):
    """Whether to use the datastore (bool)."""
    if not isinstance(value, bool):
      raise datastore_errors.BadArgumentError(
          'use_datastore should be a bool (%r)' % (value,))
    return value

  @datastore_rpc.ConfigOption
  def memcache_timeout(value):
    """Memcache expiration passed as `time` to memcache writes (integer)."""
    if not isinstance(value, six.integer_types):
      raise datastore_errors.BadArgumentError(
          'memcache_timeout should be an integer (%r)' % (value,))
    return value

  @datastore_rpc.ConfigOption
  def max_memcache_items(value):
    """Batch size limit for the memcache auto-batchers (integer)."""
    if not isinstance(value, six.integer_types):
      raise datastore_errors.BadArgumentError(
          'max_memcache_items should be an integer (%r)' % (value,))
    return value

  @datastore_rpc.ConfigOption
  def memcache_deadline(value):
    """Deadline for memcache RPCs issued on behalf of datastore ops (int)."""
    if not isinstance(value, six.integer_types):
      raise datastore_errors.BadArgumentError(
          'memcache_deadline should be an integer (%r)' % (value,))
    return value


class TransactionOptions(ContextOptions, datastore_rpc.TransactionOptions):
  """Support both context options and transaction options."""


# Keyword aliases accepted by _make_ctx_options: alias -> canonical name.
_OPTION_TRANSLATIONS = {
    'options': 'config',
}


def _make_ctx_options(ctx_options, config_cls=ContextOptions):
  """Helper to construct a ContextOptions object from keyword arguments.

  Args:
    ctx_options: A dict of keyword arguments. Aliased keys (see
      `_OPTION_TRANSLATIONS`) are renamed in place, so this dict is mutated.
    config_cls: Optional Configuration class to use, default ContextOptions.

  Note that either 'options' or 'config' can be used to pass another
  Configuration object, but not both. If another Configuration object is
  given it provides default values.

  Returns:
    A Configuration object, or None if ctx_options is empty.

  Raises:
    ValueError: if both an alias and its canonical name are given.
  """
  if not ctx_options:
    return None
  # Iterate over a snapshot of the keys because the dict is modified below.
  for key in list(ctx_options):
    translation = _OPTION_TRANSLATIONS.get(key)
    if translation:
      if translation in ctx_options:
        raise ValueError('Cannot specify %s and %s at the same time' %
                         (key, translation))
      ctx_options[translation] = ctx_options.pop(key)
  return config_cls(**ctx_options)


class AutoBatcher(object):
  """Batches multiple async calls if they share the same RPC options.

  Here is an example to explain what this class does.

  Life of a `key.get_async(options)` API call:
  * `Key` gets the singleton `Context` instance and invokes `Context.get`.
  * `Context.get` calls `Context._get_batcher.add(key, options)`. This
    returns a future `fut` as the return value of `key.get_async`.
    At this moment, `key.get_async` returns.

  * When "limit" items have been queued for the same options by
    `_get_batcher.add()`, `_get_batcher` invokes its `self._todo_tasklet`,
    `Context._get_tasklet`, with the list of keys seen so far.
  * `Context._get_tasklet` fires a MultiRPC and waits on it.
  * Upon MultiRPC completion, `Context._get_tasklet` passes on the results
    to the respective `fut` from `key.get_async`.

  * If user calls `fut.get_result()` before "limit" number of `add()` was
    called, `fut.get_result()` will repeatedly call `eventloop.run1()`.
  * After processing immediate callbacks, `eventloop` will run idlers.
    `AutoBatcher._on_idle` is an idler.
  * `_on_idle` will run the `todo_tasklet` before the batch is full.

  So the engine is `todo_tasklet`, which is a proxy tasklet that can combine
  arguments into batches and passes along results back to respective
  futures. This class is mainly a helper that invokes `todo_tasklet` with
  the right arguments at the right time.
  """

  def __init__(self, todo_tasklet, limit):
    """Init.

    Args:
      todo_tasklet: The tasklet that actually fires RPC and waits on a
        MultiRPC. It should take a list of (future, arg) pairs and an
        "options" as arguments. "options" are rpc options.
      limit: Max number of items to batch for each distinct value of
        "options".
    """
    self._todo_tasklet = todo_tasklet
    self._limit = limit
    # options -> list of (future, arg) pairs not yet sent.
    self._queues = {}
    # Batch futures returned by todo_tasklet that have not finished yet.
    self._running = []
    # (arg, options) -> future, used by add_once() to share pending calls.
    self._cache = {}

  def __repr__(self):
    return '%s(%s)' % (self.__class__.__name__, self._todo_tasklet.__name__)

  def run_queue(self, options, todo):
    """Actually run the `_todo_tasklet` on one queued batch."""
    utils.logging_debug('AutoBatcher(%s): %d items',
                        self._todo_tasklet.__name__, len(todo))
    batch_fut = self._todo_tasklet(todo, options)
    self._running.append(batch_fut)
    batch_fut.add_callback(self._finished_callback, batch_fut, todo)

  def _on_idle(self):
    """An idler eventloop can run.

    Eventloop calls this when it has finished processing all immediate
    callbacks. This method runs _todo_tasklet even before the batch is full.

    Returns:
      None when there was no queued batch to run; True otherwise.
    """
    if not self.action():
      return None
    return True

  def add(self, arg, options=None):
    """Returns back an instance of future after adding an arg.

    Args:
      arg: One argument for `_todo_tasklet`.
      options: RPC options.

    Return:
      An instance of future, representing the result of running
        `_todo_tasklet` without batching.
    """
    fut = tasklets.Future('%s.add(%s, %s)' % (self, arg, options))
    todo = self._queues.get(options)
    if todo is None:
      utils.logging_debug('AutoBatcher(%s): creating new queue for %r',
                          self._todo_tasklet.__name__, options)
      # Register the idler only when the first queue appears, so that a
      # partially filled batch is still flushed once the event loop is idle.
      if not self._queues:
        eventloop.add_idle(self._on_idle)
      todo = self._queues[options] = []
    todo.append((fut, arg))
    if len(todo) >= self._limit:
      # The batch is full: send it right away.
      del self._queues[options]
      self.run_queue(options, todo)
    return fut

  def add_once(self, arg, options=None):
    """Like add(), but shares one pending future per (arg, options) pair.

    The shared future is forgotten as soon as it completes.
    """
    cache_key = (arg, options)
    fut = self._cache.get(cache_key)
    if fut is None:
      fut = self.add(arg, options)
      self._cache[cache_key] = fut
      fut.add_immediate_callback(self._cache.__delitem__, cache_key)
    return fut

  def action(self):
    """Runs one arbitrary queued batch.

    Returns:
      True if a batch was started, False if nothing was queued.
    """
    queues = self._queues
    if not queues:
      return False
    options, todo = queues.popitem()
    self.run_queue(options, todo)
    return True

  def _finished_callback(self, batch_fut, todo):
    """Passes exception along.

    Args:
      batch_fut: the batch future returned by running todo_tasklet.
      todo: list of (fut, arg) pairs. Each fut is the future returned by an
        add() call.

    If the batch fut was successful, it has already called fut.set_result()
    on the individual futs. This method only handles the case where the
    batch fut encountered an exception: every individual fut that is not
    done yet receives that exception.
    """
    self._running.remove(batch_fut)
    err = batch_fut.get_exception()
    if err is not None:
      tb = batch_fut.get_traceback()
      for (fut, _) in todo:
        if not fut.done():
          fut.set_exception(err, tb)

  @tasklets.tasklet
  def flush(self):
    """Runs all queued batches and waits until none are running."""
    while self._running or self.action():
      if self._running:
        yield self._running


class Context(object):
  """Coordinates in-memory cache, memcache and datastore access for NDB.

  Datastore and memcache calls are funneled through per-operation
  `AutoBatcher` instances so that individual calls are combined into
  multi-item RPCs. Whether a key uses the in-memory cache, memcache or the
  datastore is decided, in order of precedence, by per-call context options,
  the corresponding policy function, the connection's configuration, and
  finally a default of True.
  """

  def __init__(self, conn=None, auto_batcher_class=AutoBatcher, config=None,
               parent_context=None):
    """Init.

    Args:
      conn: Datastore connection; created with `model.make_connection(config)`
        when None.
      auto_batcher_class: Class used to build the batchers.
      config: Optional configuration used for batch size limits (and for
        the connection, when `conn` is None).
      parent_context: The enclosing Context; used by INDEPENDENT
        transactions to find a non-transactional ancestor.
    """
    if conn is None:
      conn = model.make_connection(config)
    self._conn = conn
    self._auto_batcher_class = auto_batcher_class
    self._parent_context = parent_context

    # Batch size limits: explicit config, then connection config, then the
    # connection class defaults.
    max_get = (datastore_rpc.Configuration.max_get_keys(config, conn.config) or
               datastore_rpc.Connection.MAX_GET_KEYS)
    max_put = (
        datastore_rpc.Configuration.max_put_entities(config, conn.config) or
        datastore_rpc.Connection.MAX_PUT_ENTITIES)
    max_delete = (
        datastore_rpc.Configuration.max_delete_keys(config, conn.config) or
        datastore_rpc.Connection.MAX_DELETE_KEYS)
    self._get_batcher = auto_batcher_class(self._get_tasklet, max_get)
    self._put_batcher = auto_batcher_class(self._put_tasklet, max_put)
    self._delete_batcher = auto_batcher_class(self._delete_tasklet,
                                              max_delete)

    max_memcache = (ContextOptions.max_memcache_items(config, conn.config) or
                    datastore_rpc.Connection.MAX_GET_KEYS)
    self._memcache_get_batcher = auto_batcher_class(
        self._memcache_get_tasklet, max_memcache)
    self._memcache_set_batcher = auto_batcher_class(
        self._memcache_set_tasklet, max_memcache)
    self._memcache_del_batcher = auto_batcher_class(
        self._memcache_del_tasklet, max_memcache)
    self._memcache_off_batcher = auto_batcher_class(
        self._memcache_off_tasklet, max_memcache)
    self._batchers = [self._get_batcher,
                      self._put_batcher,
                      self._delete_batcher,
                      self._memcache_get_batcher,
                      self._memcache_set_batcher,
                      self._memcache_del_batcher,
                      self._memcache_off_batcher,
                     ]
    # Key -> entity (or None for a known miss/deletion).
    self._cache = {}
    self._memcache = memcache.Client()
    # Callbacks registered with call_on_commit() inside a transaction.
    self._on_commit_queue = []

  # Prefix prepended to `key.urlsafe()` to build memcache keys.
  _memcache_prefix = b'NDB9:'

  @tasklets.tasklet
  def flush(self):
    """Flushes every batcher until none has queued or running work."""
    more = True
    while more:
      yield [batcher.flush() for batcher in self._batchers]
      # Flushing one batcher can enqueue work on another; loop until stable.
      more = False
      for batcher in self._batchers:
        if batcher._running or batcher._queues:
          more = True
          break

  @tasklets.tasklet
  def _get_tasklet(self, todo, options):
    """Batch tasklet: fetches all queued keys with one async_get RPC.

    Args:
      todo: List of (future, key) pairs; each future receives its entity.
      options: Datastore RPC options shared by the batch.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    datastore_keys = []
    for unused_fut, key in todo:
      datastore_keys.append(key)
    entities = yield self._conn.async_get(options, datastore_keys)
    for ent, (fut, unused_key) in zip(entities, todo):
      fut.set_result(ent)

  @tasklets.tasklet
  def _put_tasklet(self, todo, options):
    """Batch tasklet: writes all queued entities with one async_put RPC.

    Each entity's `_key` is updated to the key returned by the datastore and
    each future receives that key.

    Raises:
      datastore_errors.BadKeyError: if the datastore returns a key different
        from an entity's already-complete key.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    datastore_entities = []
    for unused_fut, ent in todo:
      datastore_entities.append(ent)
    keys = yield self._conn.async_put(options, datastore_entities)
    for key, (fut, ent) in zip(keys, todo):
      if key != ent._key:
        if ent._has_complete_key():
          ent_key = ent._key
          raise datastore_errors.BadKeyError(
              'Entity Key differs from the one returned by Datastore. '
              'Returned Key: %r, Entity Key: %r' % (key, ent_key))
        ent._key = key
      fut.set_result(key)

  @tasklets.tasklet
  def _delete_tasklet(self, todo, options):
    """Batch tasklet: deletes all queued keys with one async_delete RPC.

    Each future's result is set to None once the RPC completes.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    futures = []
    datastore_keys = []
    for fut, key in todo:
      futures.append(fut)
      datastore_keys.append(key)
    yield self._conn.async_delete(options, datastore_keys)
    for fut in futures:
      fut.set_result(None)

  @staticmethod
  def default_cache_policy(key):
    """Default cache policy.

    This defers to `_use_cache` on the `Model` class.

    Args:
      key: Key instance.

    Returns:
      A bool or `None`.
    """
    flag = None
    if key is not None:
      modelclass = model.Model._kind_map.get(key.kind())
      if modelclass is not None:
        policy = getattr(modelclass, '_use_cache', None)
        if policy is not None:
          if isinstance(policy, bool):
            flag = policy
          else:
            flag = policy(key)
    return flag

  _cache_policy = default_cache_policy

  def get_cache_policy(self):
    """Return the current context cache policy function.

    Returns:
      A function that accepts a `Key` instance as argument and returns
      a bool indicating if it should be cached. May be `None`.
    """
    return self._cache_policy

  def set_cache_policy(self, func):
    """Set the context cache policy function.

    Args:
      func: A function that accepts a `Key` instance as argument and returns
        a bool indicating if it should be cached. May be `None` (restores
        the default policy) or a bool (applies to every key).
    """
    if func is None:
      func = self.default_cache_policy
    elif isinstance(func, bool):
      func = lambda unused_key, flag=func: flag
    self._cache_policy = func

  def _use_cache(self, key, options=None):
    """Return whether to use the context cache for this key.

    Args:
      key: Key instance.
      options: ContextOptions instance, or None.

    Returns:
      True if the key should be cached, False otherwise.
    """
    flag = ContextOptions.use_cache(options)
    if flag is None:
      flag = self._cache_policy(key)
    if flag is None:
      flag = ContextOptions.use_cache(self._conn.config)
    if flag is None:
      flag = True
    return flag

  @staticmethod
  def default_memcache_policy(key):
    """Default Memcache policy.

    This defers to `_use_memcache` on the `Model` class.

    Args:
      key: Key instance.

    Returns:
      A bool or `None`.
    """
    flag = None
    if key is not None:
      modelclass = model.Model._kind_map.get(key.kind())
      if modelclass is not None:
        policy = getattr(modelclass, '_use_memcache', None)
        if policy is not None:
          if isinstance(policy, bool):
            flag = policy
          else:
            flag = policy(key)
    return flag

  _memcache_policy = default_memcache_policy

  def get_memcache_policy(self):
    """Return the current memcache policy function.

    Returns:
      A function that accepts a `Key` instance as argument and returns
      a bool indicating if it should be cached. May be `None`.
    """
    return self._memcache_policy

  def set_memcache_policy(self, func):
    """Set the memcache policy function.

    Args:
      func: A function that accepts a Key instance as argument and returns
        a bool indicating if it should be cached. May be None (restores the
        default policy) or a bool (applies to every key).
    """
    if func is None:
      func = self.default_memcache_policy
    elif isinstance(func, bool):
      func = lambda unused_key, flag=func: flag
    self._memcache_policy = func

  def _use_memcache(self, key, options=None):
    """Return whether to use memcache for this key.

    Args:
      key: Key instance.
      options: ContextOptions instance, or None.

    Returns:
      True if the key should be cached in memcache, False otherwise.
    """
    flag = ContextOptions.use_memcache(options)
    if flag is None:
      flag = self._memcache_policy(key)
    if flag is None:
      flag = ContextOptions.use_memcache(self._conn.config)
    if flag is None:
      flag = True
    return flag

  @staticmethod
  def default_datastore_policy(key):
    """Default Datastore policy.

    This defers to `_use_datastore` on the `Model` class.

    Args:
      key: Key instance.

    Returns:
      A bool or `None`.
    """
    flag = None
    if key is not None:
      modelclass = model.Model._kind_map.get(key.kind())
      if modelclass is not None:
        policy = getattr(modelclass, '_use_datastore', None)
        if policy is not None:
          if isinstance(policy, bool):
            flag = policy
          else:
            flag = policy(key)
    return flag

  _datastore_policy = default_datastore_policy

  def get_datastore_policy(self):
    """Return the current context datastore policy function.

    Returns:
      A function that accepts a `Key` instance as argument and returns
      a bool indicating if it should use the datastore. May be `None`.
    """
    return self._datastore_policy

  def set_datastore_policy(self, func):
    """Set the context datastore policy function.

    Args:
      func: A function that accepts a `Key` instance as argument and returns
        a bool indicating if it should use the datastore. May be `None`
        (restores the default policy) or a bool (applies to every key).
    """
    if func is None:
      func = self.default_datastore_policy
    elif isinstance(func, bool):
      func = lambda unused_key, flag=func: flag
    self._datastore_policy = func

  def _use_datastore(self, key, options=None):
    """Return whether to use the datastore for this key.

    Args:
      key: Key instance.
      options: ContextOptions instance, or None.

    Returns:
      True if the datastore should be used, False otherwise.
    """
    flag = ContextOptions.use_datastore(options)
    if flag is None:
      flag = self._datastore_policy(key)
    if flag is None:
      flag = ContextOptions.use_datastore(self._conn.config)
    if flag is None:
      flag = True
    return flag

  @staticmethod
  def default_memcache_timeout_policy(key):
    """Default Memcache timeout policy.

    This defers to `_memcache_timeout` on the `Model` class.

    Args:
      key: Key instance.

    Returns:
      Memcache timeout to use (integer), or `None`.
    """
    timeout = None
    if key is not None and isinstance(key, model.Key):
      modelclass = model.Model._kind_map.get(key.kind())
      if modelclass is not None:
        policy = getattr(modelclass, '_memcache_timeout', None)
        if policy is not None:
          if isinstance(policy, six.integer_types):
            timeout = policy
          else:
            timeout = policy(key)
    return timeout

  _memcache_timeout_policy = default_memcache_timeout_policy

  def set_memcache_timeout_policy(self, func):
    """Set the policy function for memcache timeout (expiration).

    If the function returns `0`, it implies the default timeout.

    Args:
      func: A function that accepts a key instance as argument and returns
        an integer indicating the desired memcache timeout. May be `None`
        (restores the default policy) or an integer (applies to every key).
    """
    if func is None:
      func = self.default_memcache_timeout_policy
    elif isinstance(func, six.integer_types):
      func = lambda unused_key, flag=func: flag
    self._memcache_timeout_policy = func

  def get_memcache_timeout_policy(self):
    """Return the current policy function for memcache timeout (expiration)."""
    return self._memcache_timeout_policy

  def _get_memcache_timeout(self, key, options=None):
    """Return the memcache timeout (expiration) for this key; 0 by default."""
    timeout = ContextOptions.memcache_timeout(options)
    if timeout is None:
      timeout = self._memcache_timeout_policy(key)
    if timeout is None:
      timeout = ContextOptions.memcache_timeout(self._conn.config)
    if timeout is None:
      timeout = 0
    return timeout

  def _get_memcache_deadline(self, options=None):
    """Return the memcache RPC deadline.

    Not to be confused with the memcache timeout, or expiration.

    This is only used by datastore operations when using memcache
    as a cache; it is ignored by the direct memcache calls.

    There is no way to vary this per key or per entity; you must either
    set it on a specific call (e.g. key.get(memcache_deadline=1) or
    in the configuration options of the context's connection.
    """
    return ContextOptions.memcache_deadline(options, self._conn.config)

  def _load_from_cache_if_available(self, key):
    """Ends the calling tasklet with the cached entity for `key`, if any.

    Args:
      key: Key instance.

    Raises:
      tasklets.Return: carrying the cached value when `key` is in the
        context cache and the cached value is None (a known miss) or an
        entity whose key equals `key`. When called from a tasklet this makes
        the tasklet finish with that value. Otherwise returns None.
    """
    if key in self._cache:
      entity = self._cache[key]
      if entity is None or entity._key == key:
        raise tasklets.Return(entity)

  @tasklets.tasklet
  def get(self, key, **ctx_options):
    """Returns a `Model` instance given the entity key.

    It will use the context cache if the cache policy for the given
    key is enabled.

    Lookup order: context cache, memcache, then datastore. On a memcache
    miss with the datastore enabled, the memcache entry is first set to a
    lock placeholder and re-read with CAS ids; the datastore result is
    then written back with memcache_cas, and only if the entry was not
    already locked by someone else.

    Args:
      key: Key instance.
      **ctx_options: Context options.

    Returns:
      A `Model` instance if the key exists in the datastore; `None`
      otherwise.
    """
    options = _make_ctx_options(ctx_options)
    use_cache = self._use_cache(key, options)
    if use_cache:
      self._load_from_cache_if_available(key)

    use_datastore = self._use_datastore(key, options)
    # Inside a transaction, reads go to the datastore, bypassing memcache.
    if (use_datastore and
        isinstance(self._conn, datastore_rpc.TransactionalConnection)):
      use_memcache = False
    else:
      use_memcache = self._use_memcache(key, options)
    ns = key.namespace()
    memcache_deadline = None

    if use_memcache:
      mkey = self._memcache_prefix + key.urlsafe()
      memcache_deadline = self._get_memcache_deadline(options)
      mvalue = yield self.memcache_get(mkey, for_cas=use_datastore,
                                       namespace=ns, use_cache=True,
                                       deadline=memcache_deadline)
      # The context cache may have been filled while we were yielding.
      if use_cache:
        self._load_from_cache_if_available(key)
      if mvalue not in (_LOCKED, None):
        cls = model.Model._lookup_model(key.kind(),
                                        self._conn.adapter.default_model)
        pb = entity_pb2.EntityProto()

        try:
          pb.MergeFromString(mvalue)
        except message.DecodeError:
          logging.warning('Corrupt memcache entry found '
                          'with key %s and namespace %s', mkey, ns)
          mvalue = None
        else:
          entity = cls._from_pb(pb)
          # The key is not serialized into memcache; restore it.
          entity._key = key
          if use_cache:
            self._cache[key] = entity
          raise tasklets.Return(entity)

      if mvalue is None and use_datastore:
        # Lock the entry, then re-read it to obtain a CAS id for the
        # memcache_cas write-back below.
        yield self.memcache_set(mkey, _LOCKED, time=_LOCK_TIME, namespace=ns,
                                use_cache=True, deadline=memcache_deadline)
        yield self.memcache_gets(mkey, namespace=ns, use_cache=True,
                                 deadline=memcache_deadline)

    if not use_datastore:
      # Deliberately not cached as a miss.
      raise tasklets.Return(None)

    if use_cache:
      entity = yield self._get_batcher.add_once(key, options)
    else:
      entity = yield self._get_batcher.add(key, options)

    if entity is not None:
      if use_memcache and mvalue != _LOCKED:
        # The key is already encoded in the memcache key; don't serialize it.
        pbs = entity._to_pb(set_key=False).SerializePartialToString()
        # Skip values memcache would reject instead of failing the batch.
        if len(pbs) <= memcache.MAX_VALUE_SIZE:
          timeout = self._get_memcache_timeout(key, options)
          yield self.memcache_cas(mkey, pbs, time=timeout, namespace=ns,
                                  deadline=memcache_deadline)

    if use_cache:
      # Cache both hits and misses (None).
      self._cache[key] = entity

    raise tasklets.Return(entity)

  @tasklets.tasklet
  def put(self, entity, **ctx_options):
    """Writes `entity` according to the cache/memcache/datastore policies.

    With the datastore enabled, a complete key's memcache entry is locked
    before the write and (outside a transaction) deleted afterwards. With
    the datastore disabled, the serialized entity is stored in memcache
    directly.

    Args:
      entity: Model instance to store.
      **ctx_options: Context options.

    Returns:
      The entity's key (as returned by the datastore when it was used).

    Raises:
      ValueError: if memcache-only storage is requested and the serialized
        entity exceeds `memcache.MAX_VALUE_SIZE`.
    """
    options = _make_ctx_options(ctx_options)
    key = entity._key
    if key is None:
      # Dummy key so the policy functions can look up the model class.
      key = model.Key(entity.__class__, None)
    use_datastore = self._use_datastore(key, options)
    use_memcache = None
    memcache_deadline = None

    if entity._has_complete_key():
      use_memcache = self._use_memcache(key, options)
      if use_memcache:
        # Finish the memcache operation before starting the datastore RPC.
        memcache_deadline = self._get_memcache_deadline(options)
        mkey = self._memcache_prefix + key.urlsafe()
        ns = key.namespace()
        if use_datastore:
          yield self.memcache_set(mkey, _LOCKED, time=_LOCK_TIME,
                                  namespace=ns, use_cache=True,
                                  deadline=memcache_deadline)
        else:
          pbs = entity._to_pb(set_key=False).SerializePartialToString()
          if len(pbs) > memcache.MAX_VALUE_SIZE:
            raise ValueError('Values may not be more than %d bytes in length; '
                             'received %d bytes' % (memcache.MAX_VALUE_SIZE,
                                                    len(pbs)))
          timeout = self._get_memcache_timeout(key, options)
          yield self.memcache_set(mkey, pbs, time=timeout, namespace=ns,
                                  deadline=memcache_deadline)

    if use_datastore:
      key = yield self._put_batcher.add(entity, options)
      if not isinstance(self._conn, datastore_rpc.TransactionalConnection):
        if use_memcache is None:
          # The key was incomplete before the put, so the memcache settings
          # (including the configured RPC deadline) were not computed yet.
          use_memcache = self._use_memcache(key, options)
          if use_memcache:
            memcache_deadline = self._get_memcache_deadline(options)
        if use_memcache:
          mkey = self._memcache_prefix + key.urlsafe()
          ns = key.namespace()
          yield self.memcache_delete(mkey, namespace=ns,
                                     deadline=memcache_deadline)

    if key is not None:
      if entity._key != key:
        logging.info('replacing key %s with %s', entity._key, key)
        entity._key = key
      if self._use_cache(key, options):
        self._cache[key] = entity

    raise tasklets.Return(key)

  @tasklets.tasklet
  def delete(self, key, **ctx_options):
    """Deletes `key` according to the cache/memcache/datastore policies.

    The memcache entry is overwritten with a lock placeholder and the
    context cache records None for the key.
    """
    options = _make_ctx_options(ctx_options)
    if self._use_memcache(key, options):
      memcache_deadline = self._get_memcache_deadline(options)
      mkey = self._memcache_prefix + key.urlsafe()
      ns = key.namespace()
      yield self.memcache_set(mkey, _LOCKED, time=_LOCK_TIME, namespace=ns,
                              use_cache=True, deadline=memcache_deadline)

    if self._use_datastore(key, options):
      yield self._delete_batcher.add(key, options)

    if self._use_cache(key, options):
      self._cache[key] = None

  @tasklets.tasklet
  def allocate_ids(self, key, size=None, max=None, **ctx_options):
    """Allocates ids via the connection; returns what async_allocate_ids gives."""
    options = _make_ctx_options(ctx_options)
    lo_hi = yield self._conn.async_allocate_ids(options, key, size, max)
    raise tasklets.Return(lo_hi)

  @tasklets.tasklet
  def get_indexes(self, **ctx_options):
    """Returns the index list reported by the connection's async_get_indexes."""
    options = _make_ctx_options(ctx_options)
    index_list = yield self._conn.async_get_indexes(options)
    raise tasklets.Return(index_list)

  @utils.positional(3)
  def map_query(self, query, callback, pass_batch_into_callback=None,
                options=None, merge_future=None):
    """Runs `query` and feeds each (cache-reconciled) result to `callback`.

    Args:
      query: Query object; its results are streamed via `run_to_queue`.
      callback: Called with each entity (or `(batch, i, entity)` when
        `pass_batch_into_callback` is true); None passes results through.
      pass_batch_into_callback: See `callback`.
      options: Query options.
      merge_future: Future that collects the mapped values; a new
        `tasklets.MultiFuture` by default.

    Returns:
      The merge future. Results that resolve to a cached None (e.g. a
      deleted entity) are skipped.
    """
    mfut = merge_future
    if mfut is None:
      mfut = tasklets.MultiFuture('map_query')

    @tasklets.tasklet
    def helper():
      try:
        inq = tasklets.SerialQueueFuture()
        query.run_to_queue(inq, self._conn, options)
        while True:
          try:
            batch, i, ent = yield inq.getq()
          except EOFError:
            break
          ent = self._update_cache_from_query_result(ent, options)
          if ent is None:
            continue
          if callback is None:
            val = ent
          else:
            if pass_batch_into_callback:
              val = callback(batch, i, ent)
            else:
              val = callback(ent)
          mfut.putq(val)
      except GeneratorExit:
        raise
      except Exception as err:
        _, _, tb = sys.exc_info()
        mfut.set_exception(err, tb)
        raise
      else:
        mfut.complete()

    helper()
    return mfut

  def _update_cache_from_query_result(self, ent, options):
    """Reconciles a query result with the context cache.

    Keys (keys-only queries) and projection results are returned unchanged
    and never cached. Otherwise a valid cached value (including None) wins
    over the query result; if none exists, the result is cached.
    """
    if isinstance(ent, model.Key):
      return ent
    if ent._projection:
      return ent
    key = ent._key
    if not self._use_cache(key, options):
      return ent

    if key in self._cache:
      cached_ent = self._cache[key]
      if (cached_ent is None or
          (cached_ent.key == key and cached_ent.__class__ is ent.__class__)):
        return cached_ent

    self._cache[key] = ent
    return ent

  @utils.positional(2)
  def iter_query(self, query, callback=None, pass_batch_into_callback=None,
                 options=None):
    """Like map_query(), but results are delivered via a SerialQueueFuture."""
    return self.map_query(query, callback=callback, options=options,
                          pass_batch_into_callback=pass_batch_into_callback,
                          merge_future=tasklets.SerialQueueFuture())

  @tasklets.tasklet
  def transaction(self, callback, **ctx_options):
    """Runs `callback` in a datastore transaction, retrying failed commits.

    Propagation modes (TransactionOptions.propagation, NESTED by default):
      NESTED: error if already in a transaction.
      MANDATORY: requires an existing transaction; runs `callback` directly.
      ALLOWED: runs `callback` directly if in a transaction, else starts one.
      INDEPENDENT: starts a new transaction on the nearest
        non-transactional ancestor context.

    The commit is attempted `1 + max(0, retries)` times (retries defaults
    to 3). Raising `datastore_errors.Rollback` from `callback` rolls back
    and makes the result None. After a successful commit, the transaction's
    cache is merged into the parent, the corresponding memcache entries are
    deleted, and call_on_commit() callbacks run.

    Returns:
      The (possibly yielded) result of `callback`.

    Raises:
      datastore_errors.BadRequestError: for a propagation violation.
      datastore_errors.BadArgumentError: for an unknown propagation value.
      datastore_errors.TransactionFailedError: when all commit attempts fail.
    """
    options = _make_ctx_options(ctx_options, TransactionOptions)
    propagation = TransactionOptions.propagation(options)
    if propagation is None:
      propagation = TransactionOptions.NESTED

    mode = datastore_rpc.TransactionMode.READ_WRITE
    if ctx_options.get('read_only', False):
      mode = datastore_rpc.TransactionMode.READ_ONLY

    parent = self
    if propagation == TransactionOptions.NESTED:
      if self.in_transaction():
        raise datastore_errors.BadRequestError(
            'Nested transactions are not supported.')
    elif propagation == TransactionOptions.MANDATORY:
      if not self.in_transaction():
        raise datastore_errors.BadRequestError(
            'Requires an existing transaction.')
      result = callback()
      if isinstance(result, tasklets.Future):
        result = yield result
      raise tasklets.Return(result)
    elif propagation == TransactionOptions.ALLOWED:
      if self.in_transaction():
        result = callback()
        if isinstance(result, tasklets.Future):
          result = yield result
        raise tasklets.Return(result)
    elif propagation == TransactionOptions.INDEPENDENT:
      while parent.in_transaction():
        parent = parent._parent_context
        if parent is None:
          raise datastore_errors.BadRequestError(
              'Context without non-transactional ancestor')
    else:
      raise datastore_errors.BadArgumentError(
          'Invalid propagation value (%s).' % (propagation,))

    app = TransactionOptions.app(options) or key_module._DefaultAppId()

    retries = TransactionOptions.retries(options)
    if retries is None:
      retries = 3
    # Send any pending non-transactional work before the transaction starts.
    yield parent.flush()

    transaction = None
    tconn = None
    for _ in range(1 + max(0, retries)):
      # Read-write retries pass the previous transaction to the begin call.
      previous_transaction = (
          transaction
          if mode == datastore_rpc.TransactionMode.READ_WRITE else None)
      transaction = yield (parent._conn.async_begin_transaction(
          options, app, previous_transaction, mode))
      tconn = datastore_rpc.TransactionalConnection(
          adapter=parent._conn.adapter,
          config=parent._conn.config,
          transaction=transaction,
          _api_version=parent._conn._api_version)
      tctx = parent.__class__(conn=tconn,
                              auto_batcher_class=parent._auto_batcher_class,
                              parent_context=parent)
      tctx._old_ds_conn = datastore._GetConnection()
      ok = False
      try:
        tctx.set_memcache_policy(parent.get_memcache_policy())
        tctx.set_memcache_timeout_policy(parent.get_memcache_timeout_policy())
        tasklets.set_context(tctx)
        datastore._SetConnection(tconn)
        try:
          try:
            result = callback()
            if isinstance(result, tasklets.Future):
              result = yield result
          finally:
            yield tctx.flush()
        except GeneratorExit:
          raise
        except Exception:
          t, e, tb = sys.exc_info()
          tconn.async_rollback(options)
          if issubclass(t, datastore_errors.Rollback):
            return
          else:
            six.reraise(t, e, tb)
        else:
          ok = yield tconn.async_commit(options)
          if ok:
            parent._cache.update(tctx._cache)
            yield parent._clear_memcache(tctx._cache)
            raise tasklets.Return(result)
      finally:
        datastore._SetConnection(tctx._old_ds_conn)
        del tctx._old_ds_conn
        if ok:
          for on_commit_callback in tctx._on_commit_queue:
            on_commit_callback()

    # Out of retries.
    tconn.async_rollback(options)
    raise datastore_errors.TransactionFailedError(
        'The transaction could not be committed. Please try again.')

  def in_transaction(self):
    """Return whether a transaction is currently active."""
    return isinstance(self._conn, datastore_rpc.TransactionalConnection)

  def call_on_commit(self, callback):
    """Call a callback upon successful commit of a transaction.

    If not in a transaction, the callback is called immediately.

    In a transaction, multiple callbacks may be registered and will be
    called once the transaction commits, in the order in which they
    were registered. If the transaction fails, the callbacks will not
    be called.

    If the callback raises an exception, it bubbles up normally. This
    means:
    - If the callback is called immediately, any exception it raises will
      bubble up immediately.
    - If the call is postponed until commit, remaining callbacks will be
      skipped and the exception will bubble up through the `transaction()`
      call. However, the transaction is already committed at that point.
    """
    if not self.in_transaction():
      callback()
    else:
      self._on_commit_queue.append(callback)

  def clear_cache(self):
    """Clears the in-memory cache.

    NOTE: This does not affect memcache.
    """
    self._cache.clear()

  @tasklets.tasklet
  def _clear_memcache(self, keys):
    """Deletes the memcache entries of `keys` whose memcache policy is on."""
    keys = set(key for key in keys if self._use_memcache(key))
    futures = []
    for key in keys:
      mkey = self._memcache_prefix + key.urlsafe()
      ns = key.namespace()
      fut = self.memcache_delete(mkey, namespace=ns)
      futures.append(fut)
    yield futures

  @tasklets.tasklet
  def _memcache_get_tasklet(self, todo, options):
    """Batch tasklet for memcache_get(); options is (for_cas, ns, deadline).

    Each future receives the value found for its key, or None.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    for_cas, namespace, deadline = options
    keys = set()
    for unused_fut, key in todo:
      keys.add(key)
    rpc = memcache.create_rpc(deadline=deadline)
    results = yield self._memcache.get_multi_async(keys, for_cas=for_cas,
                                                   namespace=namespace,
                                                   rpc=rpc)
    for fut, key in todo:
      fut.set_result(results.get(key))

  @tasklets.tasklet
  def _memcache_set_tasklet(self, todo, options):
    """Batch tasklet for set/add/replace/cas.

    `options` is (opname, time, namespace, deadline); `opname` selects the
    `<opname>_multi_async` client method. Each future receives True iff
    memcache reported STORED for its key.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    opname, time, namespace, deadline = options
    methodname = opname + '_multi_async'
    method = getattr(self._memcache, methodname)
    mapping = {}
    for unused_fut, (key, value) in todo:
      mapping[key] = value
    rpc = memcache.create_rpc(deadline=deadline)
    results = yield method(mapping, time=time, namespace=namespace, rpc=rpc)
    for fut, (key, unused_value) in todo:
      if results is None:
        status = memcache.MemcacheSetResponse.ERROR
      else:
        status = results.get(key)
      fut.set_result(status == memcache.MemcacheSetResponse.STORED)

  @tasklets.tasklet
  def _memcache_del_tasklet(self, todo, options):
    """Batch tasklet for memcache_delete(); options is (seconds, ns, deadline).

    Each future receives its key's delete status, or
    `memcache.DELETE_NETWORK_FAILURE` when no status was returned.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    seconds, namespace, deadline = options
    keys = set()
    for unused_fut, key in todo:
      keys.add(key)
    rpc = memcache.create_rpc(deadline=deadline)
    statuses = yield self._memcache.delete_multi_async(keys, seconds=seconds,
                                                       namespace=namespace,
                                                       rpc=rpc)
    status_key_mapping = {}
    # statuses is falsy (e.g. None) when no per-key results came back.
    if statuses:
      # Relies on `keys` iterating in the same order as when it was sent.
      for key, status in zip(keys, statuses):
        status_key_mapping[key] = status
    for fut, key in todo:
      status = status_key_mapping.get(key, memcache.DELETE_NETWORK_FAILURE)
      fut.set_result(status)

  @tasklets.tasklet
  def _memcache_off_tasklet(self, todo, options):
    """Batch tasklet for incr/decr; options is (initial_value, ns, deadline).

    Each future receives the value offset_multi_async reports for its key.
    """
    if not todo:
      raise RuntimeError('Nothing to do.')
    initial_value, namespace, deadline = options
    mapping = {}
    for unused_fut, (key, delta) in todo:
      mapping[key] = delta
    rpc = memcache.create_rpc(deadline=deadline)
    results = yield self._memcache.offset_multi_async(
        mapping, initial_value=initial_value, namespace=namespace, rpc=rpc)
    for fut, (key, unused_delta) in todo:
      fut.set_result(results.get(key))

  def memcache_get(self, key, for_cas=False, namespace=None, use_cache=False,
                   deadline=None):
    """An auto-batching wrapper for `memcache.get()` or `.get_multi()`.

    Args:
      key: Key to get. This must be a string; no prefix is applied.
      for_cas: If `True`, request and store CAS ids on the Context.
      namespace: Optional namespace; defaults to the current namespace.
      use_cache: If `True`, share a single pending lookup between identical
        concurrent calls (see `AutoBatcher.add_once`).
      deadline: Optional deadline for the RPC.

    Returns:
      A Future (!) whose return value is the value retrieved from
      memcache, or `None`.

    Raises:
      TypeError: if `key` is not a string or `for_cas` is not a bool.
    """
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(for_cas, bool):
      raise TypeError('for_cas must be a bool; received %r' % (for_cas,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    options = (for_cas, namespace, deadline)
    batcher = self._memcache_get_batcher
    if use_cache:
      return batcher.add_once(key, options)
    else:
      return batcher.add(key, options)

  def memcache_gets(self, key, namespace=None, use_cache=False, deadline=None):
    """memcache_get() with `for_cas=True`."""
    return self.memcache_get(key, for_cas=True, namespace=namespace,
                             use_cache=use_cache, deadline=deadline)

  def memcache_set(self, key, value, time=0, namespace=None, use_cache=False,
                   deadline=None):
    """Auto-batched memcache set.

    Returns:
      A Future whose result is True iff memcache reported STORED.

    Raises:
      TypeError: if `key` is not a string or `time` is not an integer.
    """
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(time, six.integer_types):
      raise TypeError('time must be a number; received %r' % (time,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    options = ('set', time, namespace, deadline)
    batcher = self._memcache_set_batcher
    if use_cache:
      return batcher.add_once((key, value), options)
    else:
      return batcher.add((key, value), options)

  def memcache_add(self, key, value, time=0, namespace=None, deadline=None):
    """Auto-batched memcache add; Future result as for memcache_set()."""
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(time, six.integer_types):
      raise TypeError('time must be a number; received %r' % (time,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    return self._memcache_set_batcher.add((key, value),
                                          ('add', time, namespace, deadline))

  def memcache_replace(self, key, value, time=0, namespace=None,
                       deadline=None):
    """Auto-batched memcache replace; Future result as for memcache_set()."""
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(time, six.integer_types):
      raise TypeError('time must be a number; received %r' % (time,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    options = ('replace', time, namespace, deadline)
    return self._memcache_set_batcher.add((key, value), options)

  def memcache_cas(self, key, value, time=0, namespace=None, deadline=None):
    """Auto-batched memcache compare-and-set; result as for memcache_set()."""
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(time, six.integer_types):
      raise TypeError('time must be a number; received %r' % (time,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    return self._memcache_set_batcher.add((key, value),
                                          ('cas', time, namespace, deadline))

  def memcache_delete(self, key, seconds=0, namespace=None, deadline=None):
    """Auto-batched memcache delete.

    Returns:
      A Future whose result is the delete status for `key`.
    """
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(seconds, six.integer_types):
      raise TypeError('seconds must be a number; received %r' % (seconds,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    return self._memcache_del_batcher.add(key, (seconds, namespace, deadline))

  def memcache_incr(self, key, delta=1, initial_value=None, namespace=None,
                    deadline=None):
    """Auto-batched memcache increment by `delta`.

    Returns:
      A Future whose result is the value memcache reports for `key`.
    """
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(delta, six.integer_types):
      raise TypeError('delta must be a number; received %r' % (delta,))
    if initial_value is not None and not isinstance(initial_value,
                                                    six.integer_types):
      raise TypeError('initial_value must be a number or None; received %r' %
                      (initial_value,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    return self._memcache_off_batcher.add((key, delta),
                                          (initial_value, namespace, deadline))

  def memcache_decr(self, key, delta=1, initial_value=None, namespace=None,
                    deadline=None):
    """Auto-batched memcache decrement (an increment by `-delta`)."""
    if not isinstance(key, (six.text_type, six.binary_type)):
      raise TypeError('key must be a string; received %r' % (key,))
    if not isinstance(delta, six.integer_types):
      raise TypeError('delta must be a number; received %r' % (delta,))
    if initial_value is not None and not isinstance(initial_value,
                                                    six.integer_types):
      raise TypeError('initial_value must be a number or None; received %r' %
                      (initial_value,))
    if namespace is None:
      namespace = namespace_manager.get_namespace()
    return self._memcache_off_batcher.add((key, -delta),
                                          (initial_value, namespace, deadline))

  @tasklets.tasklet
  def urlfetch(self, url, payload=None, method='GET', headers=None,
               allow_truncated=False, follow_redirects=True,
               validate_certificate=None, deadline=None, callback=None):
    """Fetches `url` asynchronously through the urlfetch service.

    Args:
      url: URL to fetch.
      payload: Optional request body.
      method: HTTP method, 'GET' by default.
      headers: Optional dict of request headers; None means no headers.
      allow_truncated, follow_redirects, validate_certificate: Passed
        through to `urlfetch.make_fetch_call`.
      deadline, callback: Passed through to `urlfetch.create_rpc`.

    Returns:
      The result of the urlfetch RPC.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if headers is None:
      headers = {}
    rpc = urlfetch.create_rpc(deadline=deadline, callback=callback)
    urlfetch.make_fetch_call(rpc, url,
                             payload=payload,
                             method=method,
                             headers=headers,
                             allow_truncated=allow_truncated,
                             follow_redirects=follow_redirects,
                             validate_certificate=validate_certificate)
    result = yield rpc
    raise tasklets.Return(result)