sdks/python/apache_beam/internal/set_pickler.py (95 lines of code) (raw):
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Custom pickling logic for sets to make the serialization semi-deterministic.
To make set serialization semi-deterministic, we must pick an order for the set
elements. Sets may contain elements of types not defining a comparison "<"
operator. To provide an order, we define our own custom comparison function
which supports elements of near-arbitrary types and use that to sort the
contents of each set during serialization. Attempts at determinism are made on a
best-effort basis to improve hit rates for cached workflows and the ordering
does not define a total order for all values.
"""
import enum
import functools
def compare(lhs, rhs):
  """Three-way comparison built on the operands' "<" and ">" operators.

  Args:
    lhs: Left operand; must support "<" and ">" against rhs.
    rhs: Right operand.

  Returns:
    -1 if lhs < rhs, 1 if lhs > rhs, and 0 when neither relation holds.
  """
  if lhs < rhs:
    return -1
  # "<" did not hold, so only ">" can still tell the operands apart.
  return 1 if lhs > rhs else 0
def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth):
  """Identifies which object goes first in an (almost) total order of objects.

  Args:
    lhs: An arbitrary Python object or built-in type.
    rhs: An arbitrary Python object or built-in type.
    lhs_path: List holding the id() of every object on the traversal path from
      the root lhs object up to, but not including, lhs. The original contents
      of lhs_path are restored before the function exits, including when the
      nested comparison raises.
    rhs_path: Same as lhs_path except for the rhs.
    max_depth: Maximum recursion depth.

  Returns:
    -1, 0, or 1 depending on whether lhs or rhs goes first in the total order.
    0 if max_depth is exhausted.
    0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle).
  """
  if lhs is rhs:
    # Fast path: an object always ties with itself.
    return 0
  if type(lhs) != type(rhs):
    # Objects of different types are ordered by the text of their type.
    return compare(str(type(lhs)), str(type(rhs)))
  if type(lhs) in (int, float, bool, str, bytes, bytearray):
    # These types define "<" and ">" themselves.
    return compare(lhs, rhs)
  if isinstance(lhs, enum.Enum):
    # Enums can have values with arbitrary types. The names are strings.
    return compare(lhs.name, rhs.name)

  # To avoid exceeding the recursion depth limit, set a limit on recursion.
  max_depth -= 1
  if max_depth < 0:
    return 0

  # Check for cycles in the traversal path to avoid getting stuck in a loop.
  if id(lhs) in lhs_path or id(rhs) in rhs_path:
    return 0

  lhs_path.append(id(lhs))
  rhs_path.append(id(rhs))
  # The comparison logic is split across two functions to simplify updating
  # and restoring the traversal paths.
  try:
    return _generic_object_comparison_recursive_path(
        lhs, rhs, lhs_path, rhs_path, max_depth)
  finally:
    # Undo the pushes even if the nested comparison raised, so the caller's
    # path lists never keep ids of objects that are no longer being visited.
    lhs_path.pop()
    rhs_path.pop()
def _generic_object_comparison_recursive_path(
    lhs, rhs, lhs_path, rhs_path, max_depth):
  """Orders two objects of the same type by comparing their contents.

  Helper for generic_object_comparison, which calls it once both operands are
  known to share a type and their ids have been pushed onto the traversal
  paths.

  Ordering rules, by type of lhs:
    * tuple / list: shorter sequence first, then element by element.
    * set / frozenset: elements are sorted with sort_if_possible and the
      resulting tuples are compared.
    * dict: fewer keys first; keys are then sorted and each key and its value
      are compared pairwise.
    * anything else: compared through dir(): fewer attribute names first,
      then name by name, then attribute value by attribute value.

  Args:
    lhs: Object to compare.
    rhs: Object of the same type as lhs.
    lhs_path: Traversal path (list of ids) for the lhs side.
    rhs_path: Traversal path (list of ids) for the rhs side.
    max_depth: Remaining recursion budget, passed on to nested comparisons.

  Returns:
    -1, 0, or 1 for the first difference found, or 0 if none is found.
  """
  if type(lhs) == tuple or type(lhs) == list:
    result = compare(len(lhs), len(rhs))
    if result != 0:
      return result
    # Lengths are equal here, so zip visits every position of both sides.
    for lhs_item, rhs_item in zip(lhs, rhs):
      result = generic_object_comparison(
          lhs_item, rhs_item, lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    return 0

  if type(lhs) == frozenset or type(lhs) == set:
    # Sets have no inherent order; compare them as sorted tuples instead.
    return generic_object_comparison(
        tuple(sort_if_possible(lhs, lhs_path, rhs_path, max_depth)),
        tuple(sort_if_possible(rhs, lhs_path, rhs_path, max_depth)),
        lhs_path,
        rhs_path,
        max_depth)

  if type(lhs) == dict:
    lhs_keys = list(lhs.keys())
    rhs_keys = list(rhs.keys())
    result = compare(len(lhs_keys), len(rhs_keys))
    if result != 0:
      return result
    lhs_keys = sort_if_possible(lhs_keys, lhs_path, rhs_path, max_depth)
    rhs_keys = sort_if_possible(rhs_keys, lhs_path, rhs_path, max_depth)
    for lhs_key, rhs_key in zip(lhs_keys, rhs_keys):
      result = generic_object_comparison(
          lhs_key, rhs_key, lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
      result = generic_object_comparison(
          lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    # Every key and value tied, so the dicts tie. Returning here keeps plain
    # dicts out of the reflective dir() comparison below, which would only
    # walk the bound methods of the dict type.
    return 0

  # Generic objects: compare the attribute names reported by dir(), then the
  # attribute values themselves.
  lhs_fields = dir(lhs)
  rhs_fields = dir(rhs)
  result = compare(len(lhs_fields), len(rhs_fields))
  if result != 0:
    return result
  for lhs_name, rhs_name in zip(lhs_fields, rhs_fields):
    result = compare(lhs_name, rhs_name)
    if result != 0:
      return result
    # An attribute whose lookup raises AttributeError is treated as None.
    result = generic_object_comparison(
        getattr(lhs, lhs_name, None),
        getattr(rhs, rhs_name, None),
        lhs_path,
        rhs_path,
        max_depth)
    if result != 0:
      return result
  return 0
def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4):
  """Returns the elements of obj as a list ordered by the generic comparison.

  Args:
    obj: Iterable whose elements are to be ordered.
    lhs_path: Existing traversal path for the left-hand operand, or None when
      this call is the root of a traversal.
    rhs_path: Existing traversal path for the right-hand operand. Only used
      when lhs_path is not None.
    max_depth: Recursion budget handed to generic_object_comparison.

  Returns:
    A new list with the elements of obj in sorted order.
  """
  starts_new_traversal = lhs_path is None

  def _ordering(first, second):
    if starts_new_traversal:
      # Root call: every pairwise comparison gets fresh, empty paths.
      return generic_object_comparison(first, second, [], [], max_depth)
    # Nested call: keep extending the traversal paths of the caller.
    return generic_object_comparison(
        first, second, lhs_path, rhs_path, max_depth)

  ordered = list(obj)
  ordered.sort(key=functools.cmp_to_key(_ordering))
  return ordered
def save_set(pickler, obj):
  """Passes the elements of obj to pickler.save_set in sorted order.

  Args:
    pickler: Object providing a save_set method.
    obj: Collection whose elements are ordered with sort_if_possible before
      being handed over; the pickler receives them as a list.
  """
  ordered_elements = sort_if_possible(obj)
  pickler.save_set(ordered_elements)
def save_frozenset(pickler, obj):
  """Passes the elements of obj to pickler.save_frozenset in sorted order.

  Args:
    pickler: Object providing a save_frozenset method.
    obj: Collection whose elements are ordered with sort_if_possible before
      being handed over; the pickler receives them as a list.
  """
  ordered_elements = sort_if_possible(obj)
  pickler.save_frozenset(ordered_elements)