sdks/python/apache_beam/internal/set_pickler.py (95 lines of code) (raw):

# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Custom pickling logic for sets to make the serialization semi-deterministic. To make set serialization semi-deterministic, we must pick an order for the set elements. Sets may contain elements of types not defining a comparison "<" operator. To provide an order, we define our own custom comparison function which supports elements of near-arbitrary types and use that to sort the contents of each set during serialization. Attempts at determinism are made on a best-effort basis to improve hit rates for cached workflows and the ordering does not define a total order for all values. """ import enum import functools def compare(lhs, rhs): """Returns -1, 0, or 1 depending on whether lhs <, =, or > rhs.""" if lhs < rhs: return -1 elif lhs > rhs: return 1 else: return 0 def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth): """Identifies which object goes first in an (almost) total order of objects. Args: lhs: An arbitrary Python object or built-in type. rhs: An arbitrary Python object or built-in type. lhs_path: Traversal path from the root lhs object up to, but not including, lhs. The original contents of lhs_path are restored before the function returns. rhs_path: Same as lhs_path except for the rhs. max_depth: Maximum recursion depth. Returns: -1, 0, or 1 depending on whether lhs or rhs goes first in the total order. 0 if max_depth is exhausted. 0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle). """ if id(lhs) == id(rhs): # Fast path return 0 if type(lhs) != type(rhs): return compare(str(type(lhs)), str(type(rhs))) if type(lhs) in [int, float, bool, str, bool, bytes, bytearray]: return compare(lhs, rhs) if isinstance(lhs, enum.Enum): # Enums can have values with arbitrary types. The names are strings. return compare(lhs.name, rhs.name) # To avoid exceeding the recursion depth limit, set a limit on recursion. max_depth -= 1 if max_depth < 0: return 0 # Check for cycles in the traversal path to avoid getting stuck in a loop. if id(lhs) in lhs_path or id(rhs) in rhs_path: return 0 lhs_path.append(id(lhs)) rhs_path.append(id(rhs)) # The comparison logic is split across two functions to simplifying updating # and restoring the traversal paths. result = _generic_object_comparison_recursive_path( lhs, rhs, lhs_path, rhs_path, max_depth) lhs_path.pop() rhs_path.pop() return result def _generic_object_comparison_recursive_path( lhs, rhs, lhs_path, rhs_path, max_depth): if type(lhs) == tuple or type(lhs) == list: result = compare(len(lhs), len(rhs)) if result != 0: return result for i in range(len(lhs)): result = generic_object_comparison( lhs[i], rhs[i], lhs_path, rhs_path, max_depth) if result != 0: return result return 0 if type(lhs) == frozenset or type(lhs) == set: return generic_object_comparison( tuple(sort_if_possible(lhs, lhs_path, rhs_path, max_depth)), tuple(sort_if_possible(rhs, lhs_path, rhs_path, max_depth)), lhs_path, rhs_path, max_depth) if type(lhs) == dict: lhs_keys = list(lhs.keys()) rhs_keys = list(rhs.keys()) result = compare(len(lhs_keys), len(rhs_keys)) if result != 0: return result lhs_keys = sort_if_possible(lhs_keys, lhs_path, rhs_path, max_depth) rhs_keys = sort_if_possible(rhs_keys, lhs_path, rhs_path, max_depth) for lhs_key, rhs_key in zip(lhs_keys, rhs_keys): result = generic_object_comparison( lhs_key, rhs_key, lhs_path, rhs_path, max_depth) if result != 0: return result result = generic_object_comparison( lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth) if result != 0: return result lhs_fields = dir(lhs) rhs_fields = dir(rhs) result = compare(len(lhs_fields), len(rhs_fields)) if result != 0: return result for i in range(len(lhs_fields)): result = compare(lhs_fields[i], rhs_fields[i]) if result != 0: return result result = generic_object_comparison( getattr(lhs, lhs_fields[i], None), getattr(rhs, rhs_fields[i], None), lhs_path, rhs_path, max_depth) if result != 0: return result return 0 def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4): def cmp(lhs, rhs): if lhs_path is None: # Start the traversal at the root call to cmp. return generic_object_comparison(lhs, rhs, [], [], max_depth) else: # Continue the existing traversal path for recursive calls to cmp. return generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth) return sorted(obj, key=functools.cmp_to_key(cmp)) def save_set(pickler, obj): pickler.save_set(sort_if_possible(obj)) def save_frozenset(pickler, obj): pickler.save_frozenset(sort_if_possible(obj))