################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import base64
import collections
from abc import ABC, abstractmethod
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker.bundle_processor import SynchronousBagRuntimeState
from apache_beam.transforms import userstate
from enum import Enum
from functools import partial
from io import BytesIO
from typing import List, Tuple, Any, Dict, Collection, cast

from pyflink.datastream import ReduceFunction
from pyflink.datastream.functions import AggregateFunction
from pyflink.datastream.state import StateTtlConfig, MapStateDescriptor, OperatorStateStore
from pyflink.fn_execution.beam.beam_coders import FlinkCoder
from pyflink.fn_execution.coders import FieldCoder, MapCoder, from_type_info
from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor as pb2_StateDescriptor
from pyflink.fn_execution.internal_state import InternalKvState, N, InternalValueState, \
    InternalListState, InternalReducingState, InternalMergingState, InternalAggregatingState, \
    InternalMapState, InternalReadOnlyBroadcastState, InternalBroadcastState


class LRUCache(object):
    """
    A simple LRUCache implementation used to manage the internal runtime state.
    An internal runtime state is used to handle the data under a specific key of a "public" state.
    So the number of the internal runtime states may keep growing during the streaming task
    execution. To prevent the OOM caused by the unlimited growth, we introduce this LRUCache
    to evict the inactive internal runtime states.
    """

    def __init__(self, max_entries, default_entry):
        self._max_entries = max_entries
        self._default_entry = default_entry
        self._cache = collections.OrderedDict()
        self._on_evict = None

    def get(self, key):
        value = self._cache.pop(key, self._default_entry)
        if value != self._default_entry:
            # update the last access time
            self._cache[key] = value
        return value

    def put(self, key, value):
        self._cache[key] = value
        while len(self._cache) > self._max_entries:
            name, value = self._cache.popitem(last=False)
            if self._on_evict is not None:
                self._on_evict(name, value)

    def evict(self, key):
        value = self._cache.pop(key, self._default_entry)
        if self._on_evict is not None:
            self._on_evict(key, value)

    def evict_all(self):
        if self._on_evict is not None:
            for item in self._cache.items():
                self._on_evict(*item)
        self._cache.clear()

    def set_on_evict(self, func):
        self._on_evict = func

    def __len__(self):
        return len(self._cache)

    def __iter__(self):
        return iter(self._cache.values())

    def __contains__(self, key):
        return key in self._cache


class SynchronousKvRuntimeState(InternalKvState, ABC):
    """
    Base Class for partitioned State implementation.
    """

    def __init__(self, name: str, remote_state_backend: 'RemoteKeyedStateBackend'):
        self.name = name
        self._remote_state_backend = remote_state_backend
        self._internal_state = None
        self.namespace = None
        self._ttl_config = None
        self._cache_type = SynchronousKvRuntimeState.CacheType.ENABLE_READ_WRITE_CACHE

    def set_current_namespace(self, namespace: N) -> None:
        if namespace == self.namespace:
            return
        if self.namespace is not None:
            self._remote_state_backend.cache_internal_state(
                self._remote_state_backend._encoded_current_key, self)
        self.namespace = namespace
        self._internal_state = None

    def enable_time_to_live(self, ttl_config: StateTtlConfig):
        self._ttl_config = ttl_config
        if ttl_config.get_state_visibility() == StateTtlConfig.StateVisibility.NeverReturnExpired:
            self._cache_type = SynchronousKvRuntimeState.CacheType.DISABLE_CACHE
        elif ttl_config.get_update_type() == StateTtlConfig.UpdateType.OnReadAndWrite:
            self._cache_type = SynchronousKvRuntimeState.CacheType.ENABLE_WRITE_CACHE

        if self._cache_type != SynchronousKvRuntimeState.CacheType.ENABLE_READ_WRITE_CACHE:
            # disable read cache
            self._remote_state_backend._state_handler._state_cache._cache._max_entries = 0

    @abstractmethod
    def get_internal_state(self):
        pass

    class CacheType(Enum):
        DISABLE_CACHE = 0
        ENABLE_WRITE_CACHE = 1
        ENABLE_READ_WRITE_CACHE = 2


class SynchronousBagKvRuntimeState(SynchronousKvRuntimeState, ABC):
    """
    Base Class for State implementation backed by a :class:`SynchronousBagRuntimeState`.
    """
    def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
        super(SynchronousBagKvRuntimeState, self).__init__(name, remote_state_backend)
        self._value_coder = value_coder

    def get_internal_state(self):
        if self._internal_state is None:
            self._internal_state = self._remote_state_backend._get_internal_bag_state(
                self.name, self.namespace, self._value_coder, self._ttl_config)
        return self._internal_state

    def _maybe_clear_write_cache(self):
        if self._cache_type == SynchronousKvRuntimeState.CacheType.DISABLE_CACHE or \
                self._remote_state_backend._state_cache_size <= 0:
            self._internal_state.commit()
            self._internal_state._cleared = False
            self._internal_state._added_elements = []


class SynchronousValueRuntimeState(SynchronousBagKvRuntimeState, InternalValueState):
    """
    The runtime ValueState implementation backed by a :class:`SynchronousBagRuntimeState`.
    """

    def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
        super(SynchronousValueRuntimeState, self).__init__(name, value_coder, remote_state_backend)

    def value(self):
        for i in self.get_internal_state().read():
            return i
        return None

    def update(self, value) -> None:
        self.get_internal_state()
        self._internal_state.clear()
        self._internal_state.add(value)
        self._maybe_clear_write_cache()

    def clear(self) -> None:
        self.get_internal_state().clear()


class SynchronousMergingRuntimeState(SynchronousBagKvRuntimeState, InternalMergingState, ABC):
    """
    Base Class for MergingState implementation.
    """

    def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
        super(SynchronousMergingRuntimeState, self).__init__(
            name, value_coder, remote_state_backend)

    def merge_namespaces(self, target: N, sources: Collection[N]) -> None:
        self._remote_state_backend.merge_namespaces(self, target, sources, self._ttl_config)


class SynchronousListRuntimeState(SynchronousMergingRuntimeState, InternalListState):
    """
    The runtime ListState implementation backed by a :class:`SynchronousBagRuntimeState`.
    """

    def __init__(self, name: str, value_coder, remote_state_backend: 'RemoteKeyedStateBackend'):
        super(SynchronousListRuntimeState, self).__init__(name, value_coder, remote_state_backend)

    def add(self, v):
        self.get_internal_state().add(v)
        self._maybe_clear_write_cache()

    def get(self):
        return self.get_internal_state().read()

    def add_all(self, values):
        self.get_internal_state()._added_elements.extend(values)
        self._maybe_clear_write_cache()

    def update(self, values):
        self.clear()
        self.add_all(values)
        self._maybe_clear_write_cache()

    def clear(self):
        self.get_internal_state().clear()


class SynchronousReducingRuntimeState(SynchronousMergingRuntimeState, InternalReducingState):
    """
    The runtime ReducingState implementation backed by a :class:`SynchronousBagRuntimeState`.
    """

    def __init__(self,
                 name: str,
                 value_coder,
                 remote_state_backend: 'RemoteKeyedStateBackend',
                 reduce_function: ReduceFunction):
        super(SynchronousReducingRuntimeState, self).__init__(
            name, value_coder, remote_state_backend)
        self._reduce_function = reduce_function

    def add(self, v):
        current_value = self.get()
        if current_value is None:
            self._internal_state.add(v)
        else:
            self._internal_state.clear()
            self._internal_state.add(self._reduce_function.reduce(current_value, v))
        self._maybe_clear_write_cache()

    def get(self):
        for i in self.get_internal_state().read():
            return i
        return None

    def clear(self):
        self.get_internal_state().clear()


class SynchronousAggregatingRuntimeState(SynchronousMergingRuntimeState, InternalAggregatingState):
    """
    The runtime AggregatingState implementation backed by a :class:`SynchronousBagRuntimeState`.
    """

    def __init__(self,
                 name: str,
                 value_coder,
                 remote_state_backend: 'RemoteKeyedStateBackend',
                 agg_function: AggregateFunction):
        super(SynchronousAggregatingRuntimeState, self).__init__(
            name, value_coder, remote_state_backend)
        self._agg_function = agg_function

    def add(self, v):
        if v is None:
            self.clear()
            return
        accumulator = self._get_accumulator()
        if accumulator is None:
            accumulator = self._agg_function.create_accumulator()
        accumulator = self._agg_function.add(v, accumulator)
        self._internal_state.clear()
        self._internal_state.add(accumulator)
        self._maybe_clear_write_cache()

    def get(self):
        accumulator = self._get_accumulator()
        if accumulator is None:
            return None
        else:
            return self._agg_function.get_result(accumulator)

    def _get_accumulator(self):
        for i in self.get_internal_state().read():
            return i
        return None

    def clear(self):
        self.get_internal_state().clear()


class CachedMapState(LRUCache):

    def __init__(self, max_entries):
        super(CachedMapState, self).__init__(max_entries, None)
        self._all_data_cached = False
        self._cached_keys = set()

        def on_evict(key, value):
            if value[0]:
                self._cached_keys.remove(key)
                self._all_data_cached = False

        self.set_on_evict(on_evict)

    def set_all_data_cached(self):
        self._all_data_cached = True

    def is_all_data_cached(self):
        return self._all_data_cached

    def put(self, key, exists_and_value):
        if exists_and_value[0]:
            self._cached_keys.add(key)
        super(CachedMapState, self).put(key, exists_and_value)

    def get_cached_keys(self):
        return self._cached_keys


class IterateType(Enum):
    ITEMS = 0
    KEYS = 1
    VALUES = 2


class IteratorToken(Enum):
    """
    The token indicates the status of current underlying iterator. It can also be a UUID,
    which represents an iterator on the Java side.
    """
    NOT_START = 0
    FINISHED = 1


def create_cache_iterator(cache_dict, iterate_type, iterated_keys=None):
    if iterated_keys is None:
        iterated_keys = []
    if iterate_type == IterateType.KEYS:
        for key, (exists, value) in cache_dict.items():
            if not exists or key in iterated_keys:
                continue
            yield key, key
    elif iterate_type == IterateType.VALUES:
        for key, (exists, value) in cache_dict.items():
            if not exists or key in iterated_keys:
                continue
            yield key, value
    elif iterate_type == IterateType.ITEMS:
        for key, (exists, value) in cache_dict.items():
            if not exists or key in iterated_keys:
                continue
            yield key, (key, value)
    else:
        raise Exception("Unsupported iterate type: %s" % iterate_type)


class CachingMapStateHandler(object):
    # GET request flags
    GET_FLAG = 0
    ITERATE_FLAG = 1
    CHECK_EMPTY_FLAG = 2
    # GET response flags
    EXIST_FLAG = 0
    IS_NONE_FLAG = 1
    NOT_EXIST_FLAG = 2
    IS_EMPTY_FLAG = 3
    NOT_EMPTY_FLAG = 4
    # APPEND request flags
    DELETE = 0
    SET_NONE = 1
    SET_VALUE = 2

    def __init__(self, caching_state_handler, max_cached_map_key_entries):
        self._state_cache = caching_state_handler._state_cache
        self._underlying = caching_state_handler._underlying
        self._context = caching_state_handler._context
        self._max_cached_map_key_entries = max_cached_map_key_entries
        self._cached_iterator_num = 0

    def _get_cache_token(self):
        if not self._state_cache.is_cache_enabled():
            return None
        if self._context.user_state_cache_token:
            return self._context.user_state_cache_token
        else:
            return self._context.bundle_cache_token

    def blocking_get(self, state_key, map_key, map_key_encoder, map_value_decoder):
        cache_token = self._get_cache_token()
        if not cache_token:
            # cache disabled / no cache token, request from remote directly
            return self._get_raw(state_key, map_key, map_key_encoder, map_value_decoder)

        # lookup cache first
        cache_state_key = self._convert_to_cache_key(state_key)
        cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
        if cached_map_state is None:
            # request from remote
            exists, value = self._get_raw(state_key, map_key, map_key_encoder, map_value_decoder)
            cached_map_state = CachedMapState(self._max_cached_map_key_entries)
            cached_map_state.put(map_key, (exists, value))
            self._state_cache.put((cache_state_key, cache_token), cached_map_state)
            return exists, value
        else:
            cached_value = cached_map_state.get(map_key)
            if cached_value is None:
                if cached_map_state.is_all_data_cached():
                    return False, None

                # request from remote
                exists, value = self._get_raw(
                    state_key, map_key, map_key_encoder, map_value_decoder)
                cached_map_state.put(map_key, (exists, value))
                return exists, value
            else:
                return cached_value

    def lazy_iterator(self, state_key, iterate_type, map_key_decoder, map_value_decoder,
                      iterated_keys):
        cache_token = self._get_cache_token()
        if cache_token:
            # check if the data in the read cache can be used
            cache_state_key = self._convert_to_cache_key(state_key)
            cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
            if cached_map_state and cached_map_state.is_all_data_cached():
                return create_cache_iterator(
                    cached_map_state._cache, iterate_type, iterated_keys)

        # request from remote
        last_iterator_token = IteratorToken.NOT_START
        current_batch, iterator_token = self._iterate_raw(
            state_key, iterate_type,
            last_iterator_token,
            map_key_decoder,
            map_value_decoder)

        if cache_token and \
                iterator_token == IteratorToken.FINISHED and \
                iterate_type != IterateType.KEYS and \
                self._max_cached_map_key_entries >= len(current_batch):
            # Special case: all the data of the map state is contained in current batch,
            # and can be stored in the cached map state.
            cached_map_state = CachedMapState(self._max_cached_map_key_entries)
            cache_state_key = self._convert_to_cache_key(state_key)
            for key, value in current_batch.items():
                cached_map_state.put(key, (True, value))
            cached_map_state.set_all_data_cached()
            self._state_cache.put((cache_state_key, cache_token), cached_map_state)

        return self._lazy_remote_iterator(
            state_key,
            iterate_type,
            map_key_decoder,
            map_value_decoder,
            iterated_keys,
            iterator_token,
            current_batch)

    def _lazy_remote_iterator(
            self,
            state_key,
            iterate_type,
            map_key_decoder,
            map_value_decoder,
            iterated_keys,
            iterator_token,
            current_batch):
        if iterate_type == IterateType.KEYS:
            while True:
                for key in current_batch:
                    if key in iterated_keys:
                        continue
                    yield key, key
                if iterator_token == IteratorToken.FINISHED:
                    break
                current_batch, iterator_token = self._iterate_raw(
                    state_key,
                    iterate_type,
                    iterator_token,
                    map_key_decoder,
                    map_value_decoder)
        elif iterate_type == IterateType.VALUES:
            while True:
                for key, value in current_batch.items():
                    if key in iterated_keys:
                        continue
                    yield key, value
                if iterator_token == IteratorToken.FINISHED:
                    break
                current_batch, iterator_token = self._iterate_raw(
                    state_key,
                    iterate_type,
                    iterator_token,
                    map_key_decoder,
                    map_value_decoder)
        elif iterate_type == IterateType.ITEMS:
            while True:
                for key, value in current_batch.items():
                    if key in iterated_keys:
                        continue
                    yield key, (key, value)
                if iterator_token == IteratorToken.FINISHED:
                    break
                current_batch, iterator_token = self._iterate_raw(
                    state_key,
                    iterate_type,
                    iterator_token,
                    map_key_decoder,
                    map_value_decoder)
        else:
            raise Exception("Unsupported iterate type: %s" % iterate_type)

    def extend(self, state_key, items: List[Tuple[int, Any, Any]],
               map_key_encoder, map_value_encoder):
        cache_token = self._get_cache_token()
        if cache_token:
            # Cache lookup
            cache_state_key = self._convert_to_cache_key(state_key)
            cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
            if cached_map_state is None:
                cached_map_state = CachedMapState(self._max_cached_map_key_entries)
                self._state_cache.put((cache_state_key, cache_token), cached_map_state)
            for request_flag, map_key, map_value in items:
                if request_flag == self.DELETE:
                    cached_map_state.put(map_key, (False, None))
                elif request_flag == self.SET_NONE:
                    cached_map_state.put(map_key, (True, None))
                elif request_flag == self.SET_VALUE:
                    cached_map_state.put(map_key, (True, map_value))
                else:
                    raise Exception("Unknown flag: " + str(request_flag))
        return self._append_raw(
            state_key,
            items,
            map_key_encoder,
            map_value_encoder)

    def check_empty(self, state_key):
        cache_token = self._get_cache_token()
        if cache_token:
            # Cache lookup
            cache_state_key = self._convert_to_cache_key(state_key)
            cached_map_state = self._state_cache.peek((cache_state_key, cache_token))
            if cached_map_state is not None:
                if cached_map_state.is_all_data_cached() and \
                        len(cached_map_state.get_cached_keys()) == 0:
                    return True
                elif len(cached_map_state.get_cached_keys()) > 0:
                    return False
        return self._check_empty_raw(state_key)

    def clear(self, state_key):
        self.clear_read_cache(state_key)
        return self._underlying.clear(state_key)

    def clear_read_cache(self, state_key):
        cache_token = self._get_cache_token()
        if cache_token:
            cache_key = self._convert_to_cache_key(state_key)
            self._state_cache.invalidate((cache_key, cache_token))

    def get_cached_iterators_num(self):
        return self._cached_iterator_num

    def _inc_cached_iterators_num(self):
        self._cached_iterator_num += 1

    def _dec_cached_iterators_num(self):
        self._cached_iterator_num -= 1

    def reset_cached_iterators_num(self):
        self._cached_iterator_num = 0

    def _check_empty_raw(self, state_key):
        output_stream = coder_impl.create_OutputStream()
        output_stream.write_byte(self.CHECK_EMPTY_FLAG)
        continuation_token = output_stream.get()
        data, response_token = self._underlying.get_raw(state_key, continuation_token)
        if data[0] == self.IS_EMPTY_FLAG:
            return True
        elif data[0] == self.NOT_EMPTY_FLAG:
            return False
        else:
            raise Exception("Unknown response flag: " + str(data[0]))

    def _get_raw(self, state_key, map_key, map_key_encoder, map_value_decoder):
        output_stream = coder_impl.create_OutputStream()
        output_stream.write_byte(self.GET_FLAG)
        map_key_encoder(map_key, output_stream)
        continuation_token = output_stream.get()
        data, response_token = self._underlying.get_raw(state_key, continuation_token)
        input_stream = coder_impl.create_InputStream(data)
        result_flag = input_stream.read_byte()
        if result_flag == self.EXIST_FLAG:
            return True, map_value_decoder(input_stream)
        elif result_flag == self.IS_NONE_FLAG:
            return True, None
        elif result_flag == self.NOT_EXIST_FLAG:
            return False, None
        else:
            raise Exception("Unknown response flag: " + str(result_flag))

    def _iterate_raw(self, state_key, iterate_type, iterator_token,
                     map_key_decoder, map_value_decoder):
        output_stream = coder_impl.create_OutputStream()
        output_stream.write_byte(self.ITERATE_FLAG)
        output_stream.write_byte(iterate_type.value)
        if not isinstance(iterator_token, IteratorToken):
            # The iterator token represents a Java iterator
            output_stream.write_bigendian_int32(len(iterator_token))
            output_stream.write(iterator_token)
        else:
            output_stream.write_bigendian_int32(0)
        continuation_token = output_stream.get()
        data, response_token = self._underlying.get_raw(state_key, continuation_token)
        if len(response_token) != 0:
            # The new iterator token is an UUID which represents a cached iterator at Java
            # side.
            new_iterator_token = response_token
            if iterator_token == IteratorToken.NOT_START:
                # This is the first request but not the last request of current state.
                # It means there is a new iterator has been created and cached at Java side.
                self._inc_cached_iterators_num()
        else:
            new_iterator_token = IteratorToken.FINISHED
            if iterator_token != IteratorToken.NOT_START:
                # This is not the first request but the last request of current state.
                # It means the cached iterator created at Java side has been removed as
                # current iteration has finished.
                self._dec_cached_iterators_num()
        input_stream = coder_impl.create_InputStream(data)
        if iterate_type == IterateType.ITEMS or iterate_type == IterateType.VALUES:
            # decode both key and value
            current_batch = {}
            while input_stream.size() > 0:
                key = map_key_decoder(input_stream)
                is_not_none = input_stream.read_byte()
                if is_not_none:
                    value = map_value_decoder(input_stream)
                else:
                    value = None
                current_batch[key] = value
        else:
            # only decode key
            current_batch = []
            while input_stream.size() > 0:
                key = map_key_decoder(input_stream)
                current_batch.append(key)
        return current_batch, new_iterator_token

    def _append_raw(self, state_key, items, map_key_encoder, map_value_encoder):
        output_stream = coder_impl.create_OutputStream()
        output_stream.write_bigendian_int32(len(items))
        for request_flag, map_key, map_value in items:
            output_stream.write_byte(request_flag)
            # Not all the coder impls will serialize the length of bytes when we set the "nested"
            # param to "True", so we need to encode the length of bytes manually.
            tmp_out = coder_impl.create_OutputStream()
            map_key_encoder(map_key, tmp_out)
            serialized_data = tmp_out.get()
            output_stream.write_bigendian_int32(len(serialized_data))
            output_stream.write(serialized_data)
            if request_flag == self.SET_VALUE:
                tmp_out = coder_impl.create_OutputStream()
                map_value_encoder(map_value, tmp_out)
                serialized_data = tmp_out.get()
                output_stream.write_bigendian_int32(len(serialized_data))
                output_stream.write(serialized_data)
        return self._underlying.append_raw(state_key, output_stream.get())

    @staticmethod
    def _convert_to_cache_key(state_key):
        return state_key.SerializeToString()


class RemovableConcatIterator(collections.abc.Iterator):

    def __init__(self, internal_map_state, first, second):
        self._first = first
        self._second = second
        self._first_not_finished = True
        self._internal_map_state = internal_map_state
        self._mod_count = self._internal_map_state._mod_count
        self._last_key = None

    def __next__(self):
        self._check_modification()
        if self._first_not_finished:
            try:
                self._last_key, element = next(self._first)
                return element
            except StopIteration:
                self._first_not_finished = False
                return self.__next__()
        else:
            self._last_key, element = next(self._second)
            return element

    def remove(self):
        """
        Remove the last element returned by this iterator.
        """
        if self._last_key is None:
            raise Exception("You need to call the '__next__' method before calling "
                            "this method.")
        self._check_modification()
        # Bypass the 'remove' method of the map state to avoid triggering the commit of the write
        # cache.
        if self._internal_map_state._cleared:
            del self._internal_map_state._write_cache[self._last_key]
            if len(self._internal_map_state._write_cache) == 0:
                self._internal_map_state._is_empty = True
        else:
            self._internal_map_state._write_cache[self._last_key] = (False, None)
        self._mod_count += 1
        self._internal_map_state._mod_count += 1
        self._last_key = None

    def _check_modification(self):
        if self._mod_count != self._internal_map_state._mod_count:
            raise Exception("Concurrent modification detected. "
                            "You can not modify the map state when iterating it except using the "
                            "'remove' method of this iterator.")


class InternalSynchronousMapRuntimeState(object):

    def __init__(self,
                 map_state_handler: CachingMapStateHandler,
                 state_key,
                 map_key_coder,
                 map_value_coder,
                 max_write_cache_entries):
        self._map_state_handler = map_state_handler
        self._state_key = state_key
        self._map_key_coder = map_key_coder
        if isinstance(map_key_coder, FieldCoder):
            map_key_coder_impl = FlinkCoder(map_key_coder).get_impl()
        else:
            map_key_coder_impl = map_key_coder.get_impl()
        self._map_key_encoder, self._map_key_decoder = \
            self._get_encoder_and_decoder(map_key_coder_impl)
        self._map_value_coder = map_value_coder
        if isinstance(map_value_coder, FieldCoder):
            map_value_coder_impl = FlinkCoder(map_value_coder).get_impl()
        else:
            map_value_coder_impl = map_value_coder.get_impl()
        self._map_value_encoder, self._map_value_decoder = \
            self._get_encoder_and_decoder(map_value_coder_impl)
        self._write_cache = dict()
        self._max_write_cache_entries = max_write_cache_entries
        self._is_empty = None
        self._cleared = False
        self._mod_count = 0

    def get(self, map_key):
        if self._is_empty:
            return None
        if map_key in self._write_cache:
            exists, value = self._write_cache[map_key]
            if exists:
                return value
            else:
                return None
        if self._cleared:
            return None
        exists, value = self._map_state_handler.blocking_get(
            self._state_key, map_key, self._map_key_encoder, self._map_value_decoder)
        if exists:
            return value
        else:
            return None

    def put(self, map_key, map_value):
        self._write_cache[map_key] = (True, map_value)
        self._is_empty = False
        self._mod_count += 1
        if len(self._write_cache) >= self._max_write_cache_entries:
            self.commit()

    def put_all(self, dict_value):
        for map_key, map_value in dict_value:
            self._write_cache[map_key] = (True, map_value)
        self._is_empty = False
        self._mod_count += 1
        if len(self._write_cache) >= self._max_write_cache_entries:
            self.commit()

    def remove(self, map_key):
        if self._is_empty:
            return
        if self._cleared:
            del self._write_cache[map_key]
            if len(self._write_cache) == 0:
                self._is_empty = True
        else:
            self._write_cache[map_key] = (False, None)
            self._is_empty = None
        self._mod_count += 1
        if len(self._write_cache) >= self._max_write_cache_entries:
            self.commit()

    def contains(self, map_key):
        if self._is_empty:
            return False
        if self.get(map_key) is None:
            return False
        else:
            return True

    def is_empty(self):
        if self._is_empty is None:
            if len(self._write_cache) > 0:
                self.commit()
            self._is_empty = self._map_state_handler.check_empty(self._state_key)
        return self._is_empty

    def clear(self):
        self._cleared = True
        self._is_empty = True
        self._mod_count += 1
        self._write_cache.clear()

    def items(self):
        return RemovableConcatIterator(
            self,
            self.write_cache_iterator(IterateType.ITEMS),
            self.remote_data_iterator(IterateType.ITEMS))

    def keys(self):
        return RemovableConcatIterator(
            self,
            self.write_cache_iterator(IterateType.KEYS),
            self.remote_data_iterator(IterateType.KEYS))

    def values(self):
        return RemovableConcatIterator(
            self,
            self.write_cache_iterator(IterateType.VALUES),
            self.remote_data_iterator(IterateType.VALUES))

    def commit(self):
        to_await = None
        if self._cleared:
            to_await = self._map_state_handler.clear(self._state_key)
        if self._write_cache:
            append_items = []
            for map_key, (exists, value) in self._write_cache.items():
                if exists:
                    if value is not None:
                        append_items.append(
                            (CachingMapStateHandler.SET_VALUE, map_key, value))
                    else:
                        append_items.append((CachingMapStateHandler.SET_NONE, map_key, None))
                else:
                    append_items.append((CachingMapStateHandler.DELETE, map_key, None))
            self._write_cache.clear()
            to_await = self._map_state_handler.extend(
                self._state_key, append_items, self._map_key_encoder, self._map_value_encoder)
        if to_await:
            to_await.get()
        self._write_cache.clear()
        self._cleared = False
        self._mod_count += 1

    def write_cache_iterator(self, iterate_type):
        return create_cache_iterator(self._write_cache, iterate_type)

    def remote_data_iterator(self, iterate_type):
        if self._cleared or self._is_empty:
            return iter([])
        else:
            return self._map_state_handler.lazy_iterator(
                self._state_key,
                iterate_type,
                self._map_key_decoder,
                self._map_value_decoder,
                self._write_cache)

    @staticmethod
    def _get_encoder_and_decoder(coder):
        encoder = partial(coder.encode_to_stream, nested=True)
        decoder = partial(coder.decode_from_stream, nested=True)
        return encoder, decoder


class SynchronousMapRuntimeState(SynchronousKvRuntimeState, InternalMapState):

    def __init__(self,
                 name: str,
                 map_key_coder,
                 map_value_coder,
                 remote_state_backend: 'RemoteKeyedStateBackend'):
        super(SynchronousMapRuntimeState, self).__init__(name, remote_state_backend)
        self._map_key_coder = map_key_coder
        self._map_value_coder = map_value_coder

    def get_internal_state(self):
        if self._internal_state is None:
            self._internal_state = self._remote_state_backend._get_internal_map_state(
                self.name,
                self.namespace,
                self._map_key_coder,
                self._map_value_coder,
                self._ttl_config,
                self._cache_type)
        return self._internal_state

    def get(self, key):
        return self.get_internal_state().get(key)

    def put(self, key, value):
        self.get_internal_state().put(key, value)

    def put_all(self, dict_value):
        self.get_internal_state().put_all(dict_value)

    def remove(self, key):
        self.get_internal_state().remove(key)

    def contains(self, key):
        return self.get_internal_state().contains(key)

    def items(self):
        return self.get_internal_state().items()

    def keys(self):
        return self.get_internal_state().keys()

    def values(self):
        return self.get_internal_state().values()

    def is_empty(self):
        return self.get_internal_state().is_empty()

    def clear(self):
        self.get_internal_state().clear()


class RemoteKeyedStateBackend(object):
    """
    A keyed state backend provides methods for managing keyed state.
    """

    MERGE_NAMESAPCES_MARK = "merge_namespaces"

    def __init__(self,
                 state_handler,
                 key_coder,
                 namespace_coder,
                 state_cache_size,
                 map_state_read_cache_size,
                 map_state_write_cache_size):
        self._state_handler = state_handler
        self._map_state_handler = CachingMapStateHandler(
            state_handler, map_state_read_cache_size)
        self._key_coder_impl = key_coder.get_impl()
        self.namespace_coder = namespace_coder
        if namespace_coder:
            self._namespace_coder_impl = namespace_coder.get_impl()
        else:
            self._namespace_coder_impl = None
        self._state_cache_size = state_cache_size
        self._map_state_write_cache_size = map_state_write_cache_size
        self._all_states = {}  # type: Dict[str, SynchronousKvRuntimeState]
        self._internal_state_cache = LRUCache(self._state_cache_size, None)
        self._internal_state_cache.set_on_evict(
            lambda key, value: self.commit_internal_state(value))
        self._current_key = None
        self._encoded_current_key = None
        self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="clear_iterators",
                side_input_id="clear_iterators",
                key=self._encoded_current_key))

    def get_list_state(self, name, element_coder, ttl_config=None):
        return self._wrap_internal_bag_state(
            name,
            element_coder,
            SynchronousListRuntimeState,
            SynchronousListRuntimeState,
            ttl_config)

    def get_value_state(self, name, value_coder, ttl_config=None):
        return self._wrap_internal_bag_state(
            name,
            value_coder,
            SynchronousValueRuntimeState,
            SynchronousValueRuntimeState,
            ttl_config)

    def get_map_state(self, name, map_key_coder, map_value_coder, ttl_config=None):
        if name in self._all_states:
            self.validate_map_state(name, map_key_coder, map_value_coder)
            return self._all_states[name]
        map_state = SynchronousMapRuntimeState(name, map_key_coder, map_value_coder, self)
        if ttl_config is not None:
            map_state.enable_time_to_live(ttl_config)
        self._all_states[name] = map_state
        return map_state

    def get_reducing_state(self, name, coder, reduce_function, ttl_config=None):
        return self._wrap_internal_bag_state(
            name,
            coder,
            SynchronousReducingRuntimeState,
            partial(SynchronousReducingRuntimeState, reduce_function=reduce_function),
            ttl_config)

    def get_aggregating_state(self, name, coder, agg_function, ttl_config=None):
        return self._wrap_internal_bag_state(
            name,
            coder,
            SynchronousAggregatingRuntimeState,
            partial(SynchronousAggregatingRuntimeState, agg_function=agg_function),
            ttl_config)

    def validate_state(self, name, coder, expected_type):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, expected_type):
                raise Exception("The state name '%s' is already in use and not a %s."
                                % (name, expected_type))
            if state._value_coder != coder:
                raise Exception("State name corrupted: %s" % name)

    def validate_map_state(self, name, map_key_coder, map_value_coder):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, SynchronousMapRuntimeState):
                raise Exception("The state name '%s' is already in use and not a map state."
                                % name)
            if state._map_key_coder != map_key_coder or \
                    state._map_value_coder != map_value_coder:
                raise Exception("State name corrupted: %s" % name)

    def _wrap_internal_bag_state(
            self, name, element_coder, wrapper_type, wrap_method, ttl_config):
        if name in self._all_states:
            self.validate_state(name, element_coder, wrapper_type)
            return self._all_states[name]
        wrapped_state = wrap_method(name, element_coder, self)
        if ttl_config is not None:
            wrapped_state.enable_time_to_live(ttl_config)
        self._all_states[name] = wrapped_state
        return wrapped_state

    def _get_internal_bag_state(self, name, namespace, element_coder, ttl_config):
        encoded_namespace = self._encode_namespace(namespace)
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key, encoded_namespace))
        if cached_state is not None:
            return cached_state
        # The created internal state would not be put into the internal state cache
        # at once. The internal state cache is only updated when the current key changes.
        # The reason is that the state cache size may be smaller that the count of activated
        # state (i.e. the state with current key).
        if isinstance(element_coder, FieldCoder):
            element_coder = FlinkCoder(element_coder)
        state_spec = userstate.BagStateSpec(name, element_coder)
        internal_state = self._create_bag_state(state_spec, encoded_namespace, ttl_config)
        return internal_state

    def _get_internal_map_state(
            self, name, namespace, map_key_coder, map_value_coder, ttl_config, cache_type):
        encoded_namespace = self._encode_namespace(namespace)
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key, encoded_namespace))
        if cached_state is not None:
            return cached_state
        internal_map_state = self._create_internal_map_state(
            name, encoded_namespace, map_key_coder, map_value_coder, ttl_config, cache_type)
        return internal_map_state

    def _create_bag_state(self, state_spec: userstate.StateSpec, encoded_namespace, ttl_config) \
            -> userstate.AccumulatingRuntimeState:
        if isinstance(state_spec, userstate.BagStateSpec):
            bag_state = SynchronousBagRuntimeState(
                self._state_handler,
                state_key=self.get_bag_state_key(
                    state_spec.name, self._encoded_current_key, encoded_namespace, ttl_config),
                value_coder=state_spec.coder)
            return bag_state
        else:
            raise NotImplementedError(state_spec)

    def _create_internal_map_state(
            self, name, encoded_namespace, map_key_coder, map_value_coder, ttl_config, cache_type):
        # Currently the `beam_fn_api.proto` does not support MapState, so we use the
        # the `MultimapSideInput` message to mark the state as a MapState for now.
        state_proto = pb2_StateDescriptor()
        state_proto.state_name = name
        if ttl_config is not None:
            state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto())
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="",
                window=encoded_namespace,
                side_input_id=base64.b64encode(state_proto.SerializeToString()),
                key=self._encoded_current_key))
        if cache_type == SynchronousKvRuntimeState.CacheType.DISABLE_CACHE:
            write_cache_size = 0
        else:
            write_cache_size = self._map_state_write_cache_size
        return InternalSynchronousMapRuntimeState(
            self._map_state_handler,
            state_key,
            map_key_coder,
            map_value_coder,
            write_cache_size)

    def _encode_namespace(self, namespace):
        if namespace is not None:
            encoded_namespace = self._namespace_coder_impl.encode(namespace)
        else:
            encoded_namespace = b''
        return encoded_namespace

    def cache_internal_state(self, encoded_key, internal_kv_state: SynchronousKvRuntimeState):
        encoded_old_namespace = self._encode_namespace(internal_kv_state.namespace)
        self._internal_state_cache.put(
            (internal_kv_state.name, encoded_key, encoded_old_namespace),
            internal_kv_state.get_internal_state())

    def set_current_key(self, key):
        if key == self._current_key:
            return
        encoded_old_key = self._encoded_current_key
        for state_name, state_obj in self._all_states.items():
            if self._state_cache_size > 0:
                # cache old internal state
                self.cache_internal_state(encoded_old_key, state_obj)
            state_obj.namespace = None
            state_obj._internal_state = None
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode(self._current_key)

    def get_current_key(self):
        return self._current_key

    def commit(self):
        for internal_state in self._internal_state_cache:
            self.commit_internal_state(internal_state)
        for name, state in self._all_states.items():
            if (name, self._encoded_current_key, self._encode_namespace(state.namespace)) \
                    not in self._internal_state_cache:
                self.commit_internal_state(state._internal_state)

    def clear_cached_iterators(self):
        if self._map_state_handler.get_cached_iterators_num() > 0:
            self._clear_iterator_mark.multimap_side_input.key = self._encoded_current_key
            self._map_state_handler.clear(self._clear_iterator_mark)

    def merge_namespaces(self, state: SynchronousMergingRuntimeState, target, sources, ttl_config):
        for source in sources:
            state.set_current_namespace(source)
            self.commit_internal_state(state.get_internal_state())
        state.set_current_namespace(target)
        self.commit_internal_state(state.get_internal_state())
        encoded_target_namespace = self._encode_namespace(target)
        encoded_namespaces = [self._encode_namespace(source) for source in sources]
        self.clear_state_cache(state, [encoded_target_namespace] + encoded_namespaces)

        state_key = self.get_bag_state_key(
            state.name, self._encoded_current_key, encoded_target_namespace, ttl_config)
        state_key.bag_user_state.transform_id = self.MERGE_NAMESAPCES_MARK

        encoded_namespaces_writer = BytesIO()
        encoded_namespaces_writer.write(len(sources).to_bytes(4, 'big'))
        for encoded_namespace in encoded_namespaces:
            encoded_namespaces_writer.write(encoded_namespace)
        sources_bytes = encoded_namespaces_writer.getvalue()
        to_await = self._map_state_handler._underlying.append_raw(state_key, sources_bytes)
        if to_await:
            to_await.get()

    def clear_state_cache(self, state: SynchronousMergingRuntimeState, encoded_namespaces):
        name = state.name
        for encoded_namespace in encoded_namespaces:
            if (name, self._encoded_current_key, encoded_namespace) in self._internal_state_cache:
                # commit and clear the write cache
                self._internal_state_cache.evict(
                    (name, self._encoded_current_key, encoded_namespace))
                # currently all the SynchronousMergingRuntimeState is based on bag state
                state_key = self.get_bag_state_key(
                    name, self._encoded_current_key, encoded_namespace, None)
                # clear the read cache, the read cache is shared between map state handler and bag
                # state handler. So we can use the map state handler instead.
                self._map_state_handler.clear_read_cache(state_key)

    def get_bag_state_key(self, name, encoded_key, encoded_namespace, ttl_config):
        from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor
        state_proto = StateDescriptor()
        state_proto.state_name = name
        if ttl_config is not None:
            state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto())
        return beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                transform_id="",
                window=encoded_namespace,
                user_state_id=base64.b64encode(state_proto.SerializeToString()),
                key=encoded_key))

    @staticmethod
    def commit_internal_state(internal_state):
        if internal_state is not None:
            internal_state.commit()
        # reset the status of the internal state to reuse the object cross bundle
        if isinstance(internal_state, SynchronousBagRuntimeState):
            internal_state._cleared = False
            internal_state._added_elements = []


class SynchronousReadOnlyBroadcastRuntimeState(InternalReadOnlyBroadcastState):
    def __init__(self, name: str, internal_map_state: "InternalSynchronousMapRuntimeState"):
        self._name = name
        self._internal_map_state = internal_map_state

    def get(self, key):
        return self._internal_map_state.get(key)

    def contains(self, key) -> bool:
        return self._internal_map_state.contains(key)

    def items(self):
        return self._internal_map_state.items()

    def keys(self):
        return self._internal_map_state.keys()

    def values(self):
        return self._internal_map_state.values()

    def is_empty(self):
        return self._internal_map_state.is_empty()

    def clear(self):
        return self._internal_map_state.clear()


class SynchronousBroadcastRuntimeState(
    SynchronousReadOnlyBroadcastRuntimeState, InternalBroadcastState
):
    def __init__(self, name: str, internal_map_state: "InternalSynchronousMapRuntimeState"):
        super(SynchronousBroadcastRuntimeState, self).__init__(name, internal_map_state)

    def put(self, key, value):
        self._internal_map_state.put(key, value)

    def put_all(self, dict_value):
        self._internal_map_state.put_all(dict_value)

    def remove(self, key):
        self._internal_map_state.remove(key)

    def commit(self):
        self._internal_map_state.commit()

    def to_read_only_broadcast_state(self) -> "SynchronousReadOnlyBroadcastRuntimeState":
        return SynchronousReadOnlyBroadcastRuntimeState(self._name, self._internal_map_state)


class OperatorStateBackend(OperatorStateStore, ABC):

    @abstractmethod
    def commit(self):
        pass


class RemoteOperatorStateBackend(OperatorStateBackend):
    def __init__(
        self, state_handler, state_cache_size, map_state_read_cache_size, map_state_write_cache_size
    ):
        self._state_handler = state_handler
        self._state_cache_size = state_cache_size
        # NOTE: if user stores a state into a class member, that state actually won't be actually
        # evicted from memory (because its counter > 0)
        self._state_cache = LRUCache(state_cache_size, None)
        self._state_cache.set_on_evict(lambda _, state: state.commit())
        self._map_state_read_cache_size = map_state_read_cache_size
        self._map_state_write_cache_size = map_state_write_cache_size
        self._map_state_handler = CachingMapStateHandler(state_handler, map_state_read_cache_size)

    def get_broadcast_state(
        self, state_descriptor: MapStateDescriptor
    ) -> 'SynchronousBroadcastRuntimeState':
        state_name = state_descriptor.name
        map_coder = cast(MapCoder, from_type_info(state_descriptor.type_info))  # type: MapCoder
        key_coder = map_coder._key_coder
        value_coder = map_coder._value_coder

        if state_name in self._state_cache:
            self._validate_broadcast_state(state_name, key_coder, value_coder)
            return self._state_cache.get(state_name)

        state_proto = pb2_StateDescriptor()
        state_proto.state_name = state_name
        # Currently, MultimapKeysSideInput is used for BroadcastState
        state_key = beam_fn_api_pb2.StateKey(
            multimap_keys_side_input=beam_fn_api_pb2.StateKey.MultimapKeysSideInput(
                transform_id="",
                window=bytes(),
                side_input_id=base64.b64encode(state_proto.SerializeToString()),
            )
        )

        internal_map_state = InternalSynchronousMapRuntimeState(
            self._map_state_handler,
            state_key,
            key_coder,
            value_coder,
            self._map_state_write_cache_size,
        )

        broadcast_state = SynchronousBroadcastRuntimeState(
            state_descriptor.name, internal_map_state
        )
        self._state_cache.put(state_name, broadcast_state)
        return broadcast_state

    def commit(self):
        for state in self._state_cache:
            cast(SynchronousBroadcastRuntimeState, state).commit()

    def _validate_broadcast_state(self, name, key_coder, value_coder):
        if name in self._state_cache:
            state = cast(SynchronousBroadcastRuntimeState, self._state_cache.get(name))
            if (
                key_coder != state._internal_map_state._map_key_coder
                or value_coder != state._internal_map_state._map_value_coder
            ):
                raise Exception("State name corrupted: %s" % name)
