in mesh_tensorflow/auto_mtf/scheduler.py [0:0]
import collections
import heapq


def _minimize_peak_memory_list(graph):
"""Computes schedule according to the greedy list heuristic.
Greedy list heuristic: schedule the operation which results in the most bytes
of memory being (immediately) freed.
TODO(joshuawang): Experiment with tiebreaking by preferring more successors.
Args:
graph: an mtf.auto_mtf.graph_interface.GraphInterface.
Returns:
an iterable of integers representing the schedule.
"""
  schedule = []
  bytes_freed = {}  # {operation_name: bytes freed}
  users_of = collections.defaultdict(set)  # {tensor_name: set(operation_name)}
  in_degree = collections.defaultdict(int)  # {operation_name: in degree}
  operation_id = {}  # {operation_name: id}
  # We want an updatable priority queue, so we use the following workaround:
  # docs.python.org/2/library/heapq.html#priority-queue-implementation-notes
  priority_queue = []  # (negative bytes freed, operation name)
  # Set up the (greedy) topological sort.
  for i, operation_name in enumerate(graph.get_all_operation_names()):
    operation_id[operation_name] = i
    for input_name in graph.get_operation_input_names(operation_name):
      # Note that in _HybridGraphInterface, an operation may use a tensor
      # twice, but we deduplicate (with respect to in_degree) so that we can
      # later use users_of to decrement in_degree.
      if operation_name in users_of[input_name]:
        continue
      users_of[input_name].add(operation_name)
      in_degree[operation_name] += 1
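
  # Compute each operation's initial score: bytes immediately freed by its
  # inputs minus bytes newly allocated for its outputs.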
  for operation_name in graph.get_all_operation_names():
    bytes_freed[operation_name] = 0
    # For each input, this operation frees memory if it is the final consumer.
    for input_name in graph.get_operation_input_names(operation_name):
      if len(users_of[input_name]) == 1 and not graph.is_tensor_final(
          input_name):
        bytes_freed[operation_name] += graph.get_tensor_size(input_name)
    # For each output, this operation will require additional bytes of memory
    # (hence negative bytes freed).
    for output_name in graph.get_operation_output_names(operation_name):
      # If the output is used (or is final), then it eats memory.
      if users_of[output_name] or graph.is_tensor_final(output_name):
        bytes_freed[operation_name] -= graph.get_tensor_size(output_name)
  for operation_name in graph.get_all_operation_names():
    if in_degree[operation_name] == 0:
      heapq.heappush(priority_queue,
                     (-bytes_freed[operation_name], operation_name))
  # Do the (greedy) topological sort.
  while priority_queue:
    neg_bytes_freed, operation_name = heapq.heappop(priority_queue)
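    # Lazy deletion: a mismatch means this entry is stale (the operation's
    # priority was updated later, or it has already been scheduled).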
    if bytes_freed[operation_name] != -neg_bytes_freed:
      continue
    schedule.append(operation_id[operation_name])
    bytes_freed[operation_name] = None
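    # Decrement the in-degree of every consumer of this operation's outputs;
    # any consumer whose dependencies are now all scheduled becomes ready.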
    for output_name in graph.get_operation_output_names(operation_name):
      for other_operation_name in users_of[output_name]:
        in_degree[other_operation_name] -= 1
        if in_degree[other_operation_name] == 0:
          heapq.heappush(priority_queue,
                         (-bytes_freed[other_operation_name],
                          other_operation_name))
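    # Release this operation's inputs. If an input is left with exactly one
    # remaining (non-final) consumer, that consumer will free the tensor, so
    # its bytes_freed score improves.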
    for input_name in graph.get_operation_input_names(operation_name):
      if operation_name not in users_of[input_name]:
        # Used twice by this operation and hence already removed.
        continue
      users_of[input_name].remove(operation_name)
      if len(users_of[input_name]) != 1 or graph.is_tensor_final(input_name):
        continue
      (other_operation_name,) = users_of[input_name]
      bytes_freed[other_operation_name] += graph.get_tensor_size(input_name)
      if in_degree[other_operation_name] > 0:
        continue
      # Push another copy into the priority queue with our updated value.
      # The original copy will be ignored since it does not match bytes_freed.
      heapq.heappush(priority_queue,
                     (-bytes_freed[other_operation_name],
                      other_operation_name))
  return schedule
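
Below is a minimal, hypothetical usage sketch. The stub class, its toy graph,
and the tensor sizes are all made up; it merely duck-types the five
GraphInterface methods this function actually calls, so the heuristic's
greedy choice can be observed in isolation.

# --- Hypothetical usage sketch (not part of scheduler.py) ---
class _ToyGraph(object):  # Made-up stand-in for a GraphInterface.
  """src produces t_small (100B) and t_big (400B), each with one consumer."""

  _INPUTS = {"src": [], "small_eater": ["t_small"], "big_eater": ["t_big"]}
  _OUTPUTS = {"src": ["t_small", "t_big"], "small_eater": [], "big_eater": []}
  _SIZES = {"t_small": 100, "t_big": 400}

  def get_all_operation_names(self):
    return ["src", "small_eater", "big_eater"]

  def get_operation_input_names(self, operation_name):
    return self._INPUTS[operation_name]

  def get_operation_output_names(self, operation_name):
    return self._OUTPUTS[operation_name]

  def is_tensor_final(self, tensor_name):
    return False  # In this sketch, no tensor must outlive the schedule.

  def get_tensor_size(self, tensor_name):
    return self._SIZES[tensor_name]


# After "src" (id 0) runs, both consumers are ready; the heuristic schedules
# "big_eater" (id 2) before "small_eater" (id 1) since it frees 400 bytes
# rather than 100.
print(_minimize_peak_memory_list(_ToyGraph()))  # => [0, 2, 1]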