in memory_saving_gradients.py [0:0]
def tf_toposort(ts, within_ops=None):
    """Group the tensors in `ts` into topologically ordered levels.

    Args:
        ts: collection of tensors (objects exposing `.op`). It is iterated
            more than once, so it must be re-iterable (a list, not a
            one-shot iterator).
        within_ops: forwarded unchanged to `ge.get_forward_walk_ops`.
            NOTE(review): presumably restricts the walk to these ops --
            confirm against the graph-editor documentation.

    Returns:
        A list of non-empty lists. Each inner list holds the members of
        `ts` found in one group produced by `toposort`, in the order the
        groups were produced. Groups containing no member of `ts` are
        dropped. The order of tensors inside an inner list comes from set
        iteration and should not be relied upon.
    """
    start_ops = [tensor.op for tensor in ts]
    walked_ops = ge.get_forward_walk_ops(start_ops, within_ops=within_ops)

    # Every output tensor of a walked op depends on all of that op's inputs.
    dependencies = {
        out_tensor: set(op.inputs)
        for op in walked_ops
        for out_tensor in op.outputs
    }

    # Restrict each sorted group to the tensors the caller passed in ...
    filtered_levels = (
        list(set(level).intersection(ts)) for level in toposort(dependencies)
    )
    # ... and discard the groups that end up empty.
    return [kept for kept in filtered_levels if kept]