in tfjs-layers/src/engine/executor.ts [234:364]
/**
 * Computes concrete values for one or more `SymbolicTensor`s by running, in
 * topologically-sorted order, the layers that produce them from the values
 * supplied in `feedDict`.
 *
 * The sorted execution order and the per-tensor recipient counts are cached
 * in module-level caches keyed by the fetch names and the feed names.
 *
 * When not training, an intermediate tensor is disposed right after its
 * recipient count drops to zero, unless it was supplied in `feedDict`, is one
 * of the requested fetches, is already disposed, or comes from a stateful
 * layer. All mask tensors held by the internal feed dict are disposed before
 * returning.
 *
 * @param fetches The `SymbolicTensor`(s) whose values are requested.
 * @param feedDict Concrete values for (some of) the graph's `SymbolicTensor`s.
 *   Fetches that are present in `feedDict` are taken directly from it.
 * @param kwargs Optional extra arguments forwarded to each `Layer.apply` call.
 *   `kwargs['training']` selects training mode, in which intermediate tensors
 *   are not disposed during execution. This object is not modified; when a
 *   layer's inputs carry a mask, the mask is passed via a per-layer copy.
 * @param probe Optional memory probe. If provided, its `maxNumTensors` and
 *   `minNumTensors` fields are reset and then updated with the extremes of
 *   `memory().numTensors` observed at the start of each execution step.
 * @returns A single `Tensor` if `fetches` is a single `SymbolicTensor`,
 *   otherwise an array of `Tensor`s in the same order as `fetches`.
 */
export function execute(
    fetches: SymbolicTensor|SymbolicTensor[], feedDict: FeedDict,
    kwargs?: Kwargs, probe?: ExecutionProbe): Tensor|
    Tensor[]|[Tensor | Tensor[]] {
  const training: boolean = kwargs == null ? false : kwargs['training'];

  const arrayFetches = Array.isArray(fetches);
  const fetchArray: SymbolicTensor[] =
      arrayFetches ? fetches as SymbolicTensor[] : [fetches as SymbolicTensor];

  const outputNames = fetchArray.map(t => t.name);
  const feedNames = feedDict.names();
  // Fetches that are fed directly need no computation; the remaining slots
  // are filled in as the layers producing them are executed below.
  const finalOutputs: Tensor[] = [];
  for (const outputName of outputNames) {
    if (feedNames.indexOf(outputName) !== -1) {
      finalOutputs.push(feedDict.getValue(outputName));
    } else {
      finalOutputs.push(null);
    }
  }

  if (probe != null) {
    // For optional probing of memory footprint during execution.
    probe.maxNumTensors = -Infinity;
    probe.minNumTensors = Infinity;
  }

  // Check cache. Entries are keyed by the combination of fetches and feeds.
  const fetchAndFeedKey = outputNames.join(',') + '|' + feedNames.join(',');
  if (cachedSorted[fetchAndFeedKey] == null) {
    // Cache doesn't contain the desired combination of fetches. Compute
    // topological sort for the combination for the first time and store the
    // results for future use.
    const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
    cachedSorted[fetchAndFeedKey] = out.sorted;
    cachedRecipientCounts[fetchAndFeedKey] = out.recipientCounts;
  }
  const sorted: SymbolicTensor[] = cachedSorted[fetchAndFeedKey];
  // The counts are decremented below, so work on a copy to keep the cached
  // entry intact. In training mode nothing is disposed early and the counts
  // are never consulted, so the copy stays empty.
  const recipientCounts: {[fetchName: string]: number} = {};
  if (!training) {
    Object.assign(recipientCounts, cachedRecipientCounts[fetchAndFeedKey]);
  }

  const internalFeedDict = new FeedDict(feedDict);

  // Start iterative execution on the topologically-sorted SymbolicTensors.
  for (const symbolic of sorted) {
    if (probe != null) {
      // For optional probing of memory usage during execution.
      const numTensors = memory().numTensors;
      if (numTensors > probe.maxNumTensors) {
        probe.maxNumTensors = numTensors;
      }
      if (numTensors < probe.minNumTensors) {
        probe.minNumTensors = numTensors;
      }
    }

    const srcLayer = symbolic.sourceLayer;
    if (srcLayer instanceof InputLayer) {
      // Input values come from the feed dict; there is nothing to compute.
      continue;
    }

    const inputValues: Tensor[] = [];
    const inputMasks: Tensor[] = [];
    const tensorsToDispose: Tensor[] = [];
    let maskExists = false;
    for (const input of symbolic.inputs) {
      const value = internalFeedDict.getValue(input);
      const mask = internalFeedDict.getMask(input);
      inputValues.push(value);
      inputMasks.push(mask);
      if (mask != null) {
        maskExists = true;
      }
      if (!training) {
        // Once the last recipient of an intermediate tensor has consumed it,
        // schedule it for disposal after this layer has run.
        recipientCounts[input.name]--;
        if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
            outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
            input.sourceLayer.stateful !== true) {
          tensorsToDispose.push(value);
        }
      }
    }

    // Attach the mask to a per-layer copy of `kwargs` rather than writing it
    // into the caller's object. Writing it in place would leave the mask
    // attached for every later layer (including ones whose inputs carry no
    // mask) and would leave the caller holding a mask tensor that is disposed
    // at the end of this function.
    const layerKwargs: Kwargs = maskExists ?
        Object.assign({}, kwargs, {'mask': inputMasks[0]}) :
        kwargs;
    const outputTensors =
        toList(srcLayer.apply(inputValues, layerKwargs)) as Tensor[];
    let outputMask: Tensor|Tensor[] = null;
    if (srcLayer.supportsMasking) {
      outputMask = srcLayer.computeMask(inputValues, inputMasks);
    }

    const layerOutputs = getNodeOutputs(symbolic);
    const outputSymbolicTensors =
        Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
    for (let outIndex = 0; outIndex < outputSymbolicTensors.length;
         ++outIndex) {
      const outputSymbolic = outputSymbolicTensors[outIndex];
      if (!internalFeedDict.hasKey(outputSymbolic)) {
        // An array returned by `computeMask` holds one mask per output; pair
        // them by position instead of giving every output the first mask.
        internalFeedDict.add(
            outputSymbolic, outputTensors[outIndex],
            Array.isArray(outputMask) ? outputMask[outIndex] : outputMask);
      }
      const index = outputNames.indexOf(outputSymbolic.name);
      if (index !== -1) {
        finalOutputs[index] = outputTensors[outIndex];
      }
    }

    if (!training) {
      // Clean up Tensors that are no longer needed.
      dispose(tensorsToDispose);
    }
  }
  // NOTE(cais): Unlike intermediate tensors, we don't discard mask
  // tensors as we go, because these tensors are sometimes passed over a
  // series of multiple layers, i.e., not obeying the immediate input
  // relations in the graph. If this becomes a memory-usage concern,
  // we can improve this in the future.
  internalFeedDict.disposeMasks();

  return arrayFetches ? finalOutputs : finalOutputs[0];
}