statefun-sdk-python/statefun/request_reply_v3.py (172 lines of code) (raw):

################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import asyncio
import typing
from datetime import timedelta

import statefun.context
from statefun.core import parse_typename, ValueSpec, SdkAddress
from statefun.messages import Message, EgressMessage
from statefun.statefun_builder import StatefulFunctions, StatefulFunction

# generated function protocol
from statefun.request_reply_pb2 import ToFunction, FromFunction, Address, TypedValue
from statefun.storage import resolve, Cell

from dataclasses import dataclass


@dataclass
class DelayedMessage:
    """
    A record of either a delayed send or a request to cancel a delayed send.

    When is_cancellation is True only cancellation_token is read by
    collect_delayed; otherwise duration (in milliseconds), message and
    cancellation_token are read.
    """
    is_cancellation: typing.Optional[bool] = None
    # the delay in milliseconds, as computed by UserFacingContext.send_after.
    duration: typing.Optional[int] = None
    # NOTE: there must be no trailing comma here, otherwise the default becomes the tuple (None,).
    message: typing.Optional[Message] = None
    cancellation_token: typing.Optional[str] = None


class UserFacingContext(statefun.context.Context):
    """
    The context object handed to a user function during an invocation batch.

    It buffers everything the function asks to send (messages, delayed messages, cancellations and
    egress messages), so that it can be collected into a single response after the batch completes.
    """
    __slots__ = (
        "_self_address",
        "_outgoing_messages",
        "_outgoing_delayed_messages",
        "_outgoing_egress_messages",
        "_storage",
        "_caller")

    def __init__(self, address, storage):
        """
        :param address: the address of the function being invoked.
        :param storage: the storage object exposed to the function via the storage property.
        """
        self._self_address = address
        self._outgoing_messages = []
        self._outgoing_delayed_messages: typing.List[DelayedMessage] = []
        self._outgoing_egress_messages = []
        self._storage = storage
        # set by the request reply handler before each invocation in the batch.
        self._caller = None

    @property
    def address(self) -> SdkAddress:
        """

        :return: the address of the currently executing function. the address is of the form (typename, id)
        """
        return self._self_address

    @property
    def storage(self):
        """
        :return: the storage object that was provided to this context at construction.
        """
        return self._storage

    @property
    def caller(self):
        """
        :return: the caller address recorded for the current invocation, or None if there is none.
        """
        return self._caller

    def send(self, message: Message):
        """
        Send a message to a function.

        :param message: a message to send.
        """
        self._outgoing_messages.append(message)

    def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
        """
        Send a message to a target function after a specified delay.

        :param duration: the amount of time to wait before sending this message out.
        :param message: the message to send.
        :param cancellation_token: an optional cancellation token to associate with this message.
        """
        # the protocol carries the delay as a whole number of milliseconds (fractions are truncated).
        ms = int(duration.total_seconds() * 1000.0)
        record = DelayedMessage(is_cancellation=False,
                                duration=ms,
                                message=message,
                                cancellation_token=cancellation_token)
        self._outgoing_delayed_messages.append(record)

    def cancel_delayed_message(self, cancellation_token: str):
        """
        Cancel a delayed message (message that was sent using send_after) with a given token.

        Please note that this is a best-effort operation, since the message might have been already delivered.
        If the message was delivered, this is a no-op operation.

        :param cancellation_token: the token that was associated with the delayed message.
        """
        record = DelayedMessage(is_cancellation=True, cancellation_token=cancellation_token)
        self._outgoing_delayed_messages.append(record)

    def send_egress(self, message: EgressMessage):
        """
        Send a message to an egress.

        :param message: the EgressMessage to send.
        """
        self._outgoing_egress_messages.append(message)


# -------------------------------------------------------------------------------------------------------------------
# Protobuf Helpers
# -------------------------------------------------------------------------------------------------------------------

def sdk_address_from_pb(addr: Address) -> typing.Optional[SdkAddress]:
    """
    Convert a Protobuf Address to an SdkAddress.

    :param addr: the Protobuf address to convert.
    :return: the matching SdkAddress, or None if addr is falsy or its namespace, type and id are all empty.
    """
    if not addr or (not addr.namespace and not addr.type and not addr.id):
        return None
    return SdkAddress(namespace=addr.namespace,
                      name=addr.type,
                      id=addr.id,
                      typename=f"{addr.namespace}/{addr.type}")


# noinspection PyProtectedMember
def collect_success(ctx: UserFacingContext) -> FromFunction:
    """
    Build the response of a completed batch from everything that was buffered in the context.

    :param ctx: the context that was passed to the user function during the batch.
    :return: a FromFunction with its invocation_result populated.
    """
    pb_from_function = FromFunction()
    pb_invocation_result = pb_from_function.invocation_result
    collect_messages(ctx._outgoing_messages, pb_invocation_result)
    collect_delayed(ctx._outgoing_delayed_messages, pb_invocation_result)
    collect_egress(ctx._outgoing_egress_messages, pb_invocation_result)
    collect_mutations(ctx._storage._cells, pb_invocation_result)
    return pb_from_function


def collect_failure(missing_state_specs: typing.List[ValueSpec]) -> FromFunction:
    """
    Build the response that reports state values which were missing from the request.

    :param missing_state_specs: the specs of the values that could not be resolved.
    :return: a FromFunction with its incomplete_invocation_context populated, one entry per spec.
    """
    pb_from_function = FromFunction()
    incomplete_context = pb_from_function.incomplete_invocation_context
    missing_values = incomplete_context.missing_values
    for state_spec in missing_state_specs:
        missing_value = missing_values.add()
        missing_value.state_name = state_spec.name
        missing_value.type_typename = state_spec.type.typename

        protocol_expiration_spec = FromFunction.ExpirationSpec()
        # after_call takes precedence over after_write; with neither set there is no expiration.
        if state_spec.after_call:
            protocol_expiration_spec.expire_after_millis = state_spec.duration
            protocol_expiration_spec.mode = FromFunction.ExpirationSpec.ExpireMode.AFTER_INVOKE
        elif state_spec.after_write:
            protocol_expiration_spec.expire_after_millis = state_spec.duration
            protocol_expiration_spec.mode = FromFunction.ExpirationSpec.ExpireMode.AFTER_WRITE
        else:
            protocol_expiration_spec.mode = FromFunction.ExpirationSpec.ExpireMode.NONE
        missing_value.expiration_spec.CopyFrom(protocol_expiration_spec)
    return pb_from_function


def collect_messages(messages: typing.List[Message], pb_invocation_result):
    """
    Add an outgoing message entry to the invocation result for each of the given messages.

    :param messages: the messages that were sent via the context.
    :param pb_invocation_result: the Protobuf invocation result to append to.
    """
    pb_outgoing_messages = pb_invocation_result.outgoing_messages
    for message in messages:
        outgoing = pb_outgoing_messages.add()

        namespace, name = parse_typename(message.target_typename)
        outgoing.target.namespace = namespace
        outgoing.target.type = name
        outgoing.target.id = message.target_id
        outgoing.argument.CopyFrom(message.typed_value)


def collect_delayed(delayed_messages: typing.List[DelayedMessage], invocation_result):
    """
    Add a delayed invocation entry to the invocation result for each of the given records.

    Cancellation records only carry the token and the cancellation flag; all other records carry
    the target, the delay, the argument and (when it is not None) the cancellation token.

    :param delayed_messages: the records created by send_after and cancel_delayed_message.
    :param invocation_result: the Protobuf invocation result to append to.
    """
    delayed_invocations = invocation_result.delayed_invocations
    for delayed_message in delayed_messages:
        outgoing = delayed_invocations.add()

        if delayed_message.is_cancellation:
            # handle cancellation
            outgoing.cancellation_token = delayed_message.cancellation_token
            outgoing.is_cancellation_request = True
        else:
            message = delayed_message.message
            namespace, name = parse_typename(message.target_typename)

            outgoing.target.namespace = namespace
            outgoing.target.type = name
            outgoing.target.id = message.target_id
            outgoing.delay_in_ms = delayed_message.duration
            outgoing.argument.CopyFrom(message.typed_value)
            if delayed_message.cancellation_token is not None:
                outgoing.cancellation_token = delayed_message.cancellation_token


def collect_cancellations(tokens: typing.List[str], invocation_result):
    """
    Add a delay cancellation entry to the invocation result for each non empty token.

    :param tokens: the cancellation tokens; falsy entries (empty or None) are skipped.
    :param invocation_result: the Protobuf invocation result to append to.
    """
    outgoing_cancellations = invocation_result.outgoing_delay_cancellations
    for token in tokens:
        if token:
            delay_cancellation = outgoing_cancellations.add()
            delay_cancellation.cancellation_token = token


def collect_egress(egresses: typing.List[EgressMessage], invocation_result):
    """
    Add an outgoing egress entry to the invocation result for each of the given egress messages.

    :param egresses: the egress messages that were sent via the context.
    :param invocation_result: the Protobuf invocation result to append to.
    """
    outgoing_egresses = invocation_result.outgoing_egresses
    for message in egresses:
        outgoing = outgoing_egresses.add()

        namespace, name = parse_typename(message.typename)
        outgoing.egress_namespace = namespace
        outgoing.egress_type = name
        outgoing.argument.CopyFrom(message.typed_value)


def collect_mutations(cells: typing.Dict[str, Cell], invocation_result):
    """
    Add a state mutation entry to the invocation result for each dirty cell.

    :param cells: the storage cells, keyed by state name. Cells that are not dirty are skipped.
    :param invocation_result: the Protobuf invocation result to append to.
    """
    for key, cell in cells.items():
        if not cell.dirty:
            continue
        mutation = invocation_result.state_mutations.add()
        mutation.state_name = key
        val: typing.Optional[TypedValue] = cell.typed_value
        if val is None:
            # it is deleted.
            mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('DELETE')
        else:
            mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('MODIFY')
            mutation.state_value.CopyFrom(val)


# --------------------------------------------------------------------------------------------------------------------
# The main Request Reply Handler.
# --------------------------------------------------------------------------------------------------------------------


class RequestReplyHandler(object):
    """
    Handles a serialized ToFunction request by invoking the targeted function and serializing the result.
    """

    def __init__(self, functions: StatefulFunctions):
        """
        :param functions: the registered functions to dispatch to.
        :raises ValueError: if functions is falsy.
        """
        if not functions:
            raise ValueError("functions must be provided.")
        self.functions = functions

    def handle_sync(self, request_bytes: typing.Union[str, bytes, bytearray]) -> bytes:
        """
        Handle a request synchronously, by running handle_async to completion with asyncio.run.

        Note that asyncio.run can not be called while another event loop is running in the same thread;
        use handle_async directly in that case.

        :param request_bytes: a serialized ToFunction request.
        :return: a serialized FromFunction response.
        """
        return asyncio.run(self.handle_async(request_bytes))

    async def handle_async(self, request_bytes: typing.Union[str, bytes, bytearray]) -> bytes:
        """
        Handle a request: resolve the target function and its state, invoke the batch and collect the result.

        :param request_bytes: a serialized ToFunction request.
        :return: a serialized FromFunction response. If some of the state values that the function declares
                 are missing from the request, the response reports them and the function is not invoked.
        :raises ValueError: if the request has no target address, or no function is found for the target.
        """
        # parse
        pb_to_function = ToFunction()
        pb_to_function.ParseFromString(request_bytes)
        # target address
        pb_target_address = pb_to_function.invocation.target
        sdk_address = sdk_address_from_pb(pb_target_address)
        if sdk_address is None:
            # sdk_address_from_pb returns None for an empty address, fail with a clear error instead of
            # an AttributeError on the lookup below.
            raise ValueError("The request does not specify a target address.")
        # target stateful function
        target_fn: StatefulFunction = self.functions.for_typename(sdk_address.typename)
        if not target_fn:
            raise ValueError(f"Unable to find a function of type {sdk_address.typename}")
        # resolve state
        res = resolve(target_fn.storage_spec, sdk_address.typename, pb_to_function.invocation.state)
        if res.missing_specs:
            pb_from_function = collect_failure(res.missing_specs)
            return pb_from_function.SerializeToString()
        # invoke the batch
        ctx = UserFacingContext(sdk_address, res.storage)
        fun = target_fn.fun
        is_async = target_fn.is_async
        pb_batch = pb_to_function.invocation.invocations
        for pb_invocation in pb_batch:
            msg = Message(target_typename=sdk_address.typename,
                          target_id=sdk_address.id,
                          typed_value=pb_invocation.argument)
            ctx._caller = sdk_address_from_pb(pb_invocation.caller)
            if is_async:
                # await for an async function to complete.
                # noinspection PyUnresolvedReferences
                await fun(ctx, msg)
            else:
                # a regular function is called directly.
                fun(ctx, msg)
        # collect the results
        pb_from_function = collect_success(ctx)
        return pb_from_function.SerializeToString()