# memory_saving_gradients.py
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
'''
Authors: Tim Salimans & Yaroslav Bulatov
Memory-efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)
ys, xs, grad_ys and kwargs are the arguments to the standard tensorflow tf.gradients
(https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)
'checkpoints' can either be
- a list consisting of tensors from the forward pass of the neural net
that we should re-use when calculating the gradients in the backward pass;
all other tensors that do not appear in this list will be re-computed
- a string specifying how this list should be determined. currently we support
- 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
so checkpointing them maximizes the running speed
(this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
- 'memory': try to minimize the memory usage
(currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
- 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
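Example (a minimal sketch; `hidden1`, `hidden2` and `loss` stand for tensors
from the user's own graph and are placeholders, not names used by this module):
    tf.add_to_collection('checkpoints', hidden1)
    tf.add_to_collection('checkpoints', hidden2)
    grads = gradients(loss, tf.trainable_variables(), checkpoints='collection')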
'''
# print("Calling memsaving gradients with", checkpoints)
if not isinstance(ys, list):
ys = [ys]
if not isinstance(xs, list):
xs = [xs]
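# all ops that the outputs ys depend on, found by walking the graph backward from ys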
bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
inclusive=True)
debug_print("bwd_ops: %s", bwd_ops)
# forward ops are all ops that are candidates for recomputation
fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
inclusive=True,
within_ops=bwd_ops)
debug_print("fwd_ops: %s", fwd_ops)
# exclude ops with no inputs
fwd_ops = [op for op in fwd_ops if op.inputs]
# don't recompute xs; also drop variable assign and read ops
xs_ops = _to_ops(xs)
fwd_ops = [op for op in fwd_ops if op not in xs_ops]
fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
ts_all = ge.filter_ts(fwd_ops, True) # get the tensors
ts_all = [t for t in ts_all if '/read' not in t.name]
ts_all = set(ts_all) - set(xs) - set(ys)
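# ts_all now holds the candidate tensors for checkpointing: everything produced
# by the remaining forward ops, excluding the inputs xs and outputs ys themselves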
# construct list of tensors to checkpoint during forward pass, if not
# given as input
if type(checkpoints) is not list:
if checkpoints == 'collection':
checkpoints = tf.get_collection('checkpoints')
elif checkpoints == 'speed':
# checkpoint all expensive ops to maximize running speed
checkpoints = ge.filter_ts_from_regex(
fwd_ops, 'conv2d|Conv|MatMul')
elif checkpoints == 'memory':
# remove very small tensors and ops that make poor checkpoints (losses, batch norm internals, control flow, dropout)
def fixdims(t): # tf.Dimension values are not compatible with int, convert manually
try:
return [int(e if e.value is not None else 64) for e in t]
except Exception:
return [0] # unknown shape
ts_all = [t for t in ts_all if np.prod(
fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
ts_all = [t for t in ts_all if 'entropy' not in t.name]
ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
ts_all = [t for t in ts_all if 'Switch' not in t.name]
ts_all = [t for t in ts_all if 'dropout' not in t.name]
# keep only forward tensors that are actually used as inputs by the backward graph
with util.capture_ops() as bwd_ops:
tf_gradients(ys, xs, grad_ys, **kwargs)
bwd_inputs = [t for op in bwd_ops for t in op.inputs]
# list of tensors in the forward graph that are inputs to the bwd graph
ts_filtered = list(set(bwd_inputs).intersection(ts_all))
debug_print("Using tensors %s", ts_filtered)
# try two slightly different ways of getting bottleneck tensors
# to checkpoint
for ts in [ts_filtered, ts_all]:
# get all bottlenecks in the graph
bottleneck_ts = []
for t in ts:
b = set(ge.get_backward_walk_ops(
t.op, inclusive=True, within_ops=fwd_ops))
f = set(ge.get_forward_walk_ops(
t.op, inclusive=False, within_ops=fwd_ops))
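# t is a bottleneck if the candidate tensors feeding ops before it and ops after it
# are disjoint and together cover ts_all, i.e. t separates the forward graph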
# check that there are no shortcuts
b_inp = set(
[inp for op in b for inp in op.inputs]).intersection(ts_all)
f_inp = set(
[inp for op in f for inp in op.inputs]).intersection(ts_all)
if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
bottleneck_ts.append(t) # we have a bottleneck!
else:
debug_print("Rejected bottleneck candidate and ops %s", [
t] + list(set(ts_all) - set(b_inp) - set(f_inp)))
# success? or try again without filtering?
if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found!
break
if not bottleneck_ts:
raise Exception(
'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".')
# sort the bottlenecks
bottlenecks_sorted_lists = tf_toposort(
bottleneck_ts, within_ops=fwd_ops)
sorted_bottlenecks = [
t for ts in bottlenecks_sorted_lists for t in ts]
# save an approximately optimal number ~ sqrt(N)
N = len(ts_filtered)
if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
checkpoints = sorted_bottlenecks
else:
step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
checkpoints = sorted_bottlenecks[step::step]
else:
raise Exception(
'%s is unsupported input for "checkpoints"' % (checkpoints,))
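# keep only checkpoints that are actually tensors of the forward graph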
checkpoints = list(set(checkpoints).intersection(ts_all))
# at this point automatic selection has happened and checkpoints is a list of nodes
assert isinstance(checkpoints, list)
debug_print("Checkpoint nodes used: %s", checkpoints)
# handle special cases: inputs or outputs that ended up in the checkpoint list
# xs are already handled as checkpoint nodes, so no need to include them
xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
if xs_intersect_checkpoints:
debug_print("Warning, some input nodes are also checkpoint nodes: %s",
xs_intersect_checkpoints)
ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
ys_intersect_checkpoints)
# saving an output node (ys) gives no benefit in memory while creating
# new edge cases, so exclude them
if ys_intersect_checkpoints:
debug_print("Warning, some output nodes are also checkpoints nodes: %s",
format_ops(ys_intersect_checkpoints))
# remove initial and terminal nodes from checkpoints list if present
checkpoints = list(set(checkpoints) - set(ys) - set(xs))
# check that we have some nodes to checkpoint
if not checkpoints:
raise Exception('no checkpoint nodes found or given as input!')
# disconnect dependencies between checkpointed tensors
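# each checkpoint gets a stop_gradient copy; the recomputed subgraphs below are
# rerouted to read from these copies rather than from the original tensors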
checkpoints_disconnected = {}
for x in checkpoints:
if x.op and x.op.name is not None:
grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
else:
grad_node = tf.stop_gradient(x)
checkpoints_disconnected[x] = grad_node
# partial derivatives to the checkpointed tensors and xs
ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
stop_at_ts=checkpoints, within_ops=fwd_ops)
debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
debug_print("ops_to_copy = %s", ops_to_copy)
debug_print("Processing list %s", ys)
copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
copied_ops = info._transformed_ops.values()
debug_print("Copied %s to %s", ops_to_copy, copied_ops)
ge.reroute_ts(checkpoints_disconnected.values(),
checkpoints_disconnected.keys(), can_modify=copied_ops)
debug_print("Rewired %s in place of %s restricted to %s",
checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops)
# get gradients with respect to current boundary + original x's
copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
boundary = list(checkpoints_disconnected.values())
dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
debug_print("Got gradients %s", dv)
debug_print("for %s", copied_ys)
debug_print("with respect to %s", boundary+xs)
inputs_to_do_before = [y.op for y in ys]
if grad_ys is not None:
inputs_to_do_before += grad_ys
wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
# partial derivatives to the checkpointed nodes
# dictionary of "node: backprop" for nodes in the boundary
d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
dv[:len(checkpoints_disconnected)])}
# partial derivatives to xs (usually the params of the neural net)
d_xs = dv[len(checkpoints_disconnected):]
# incorporate derivatives flowing through the checkpointed nodes
checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
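# process the checkpoints in reverse topological order, so the gradient arriving at
# each checkpoint is complete before it is propagated further toward the inputs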
for ts in checkpoints_sorted_lists[::-1]:
debug_print("Processing list %s", ts)
checkpoints_other = [r for r in checkpoints if r not in ts]
checkpoints_disconnected_other = [
checkpoints_disconnected[r] for r in checkpoints_other]
# copy part of the graph below current checkpoint node, stopping at
# other checkpoint nodes
ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[
r.op for r in ts], stop_at_ts=checkpoints_other)
debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
len(ops_to_copy), fwd_ops, [r.op for r in ts],
checkpoints_other)
debug_print("ops_to_copy = %s", ops_to_copy)
if not ops_to_copy: # we're done!
break
copied_sgv, info = ge.copy_with_input_replacements(
ge.sgv(ops_to_copy), {})
copied_ops = info._transformed_ops.values()
debug_print("Copied %s to %s", ops_to_copy, copied_ops)
ge.reroute_ts(checkpoints_disconnected_other,
checkpoints_other, can_modify=copied_ops)
debug_print("Rewired %s in place of %s restricted to %s",
checkpoints_disconnected_other, checkpoints_other, copied_ops)
# gradient flowing through the checkpointed node
boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
substitute_backprops = [d_checkpoints[r] for r in ts]
dv = tf_gradients(boundary,
checkpoints_disconnected_other+xs,
grad_ys=substitute_backprops, **kwargs)
debug_print("Got gradients %s", dv)
debug_print("for %s", boundary)
debug_print("with respect to %s", checkpoints_disconnected_other+xs)
debug_print("with boundary backprop substitutions %s",
substitute_backprops)
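# again, delay the recomputed ops until the backprop values arriving at these checkpoints are available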
inputs_to_do_before = [d_checkpoints[r].op for r in ts]
wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
# partial derivatives to the checkpointed nodes
for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
if dr is not None:
if d_checkpoints[r] is None:
d_checkpoints[r] = dr
else:
d_checkpoints[r] += dr
# partial derivatives to xs (usually the params of the neural net)
d_xs_new = dv[len(checkpoints_other):]
for j in range(len(xs)):
if d_xs_new[j] is not None:
if d_xs[j] is None:
d_xs[j] = d_xs_new[j]
else:
d_xs[j] += d_xs_new[j]
return d_xs