in tfjs-converter/src/operations/operation_mapper.ts [179:336]
/**
 * Converts a raw TensorFlow node definition into the converter's `Node`
 * representation, using the op mapper registered for `node.op`.
 *
 * The mapper is taken from `getRegisteredOp(node.op)` first, then from
 * `this.opMappers[node.op]`; when neither exists an empty mapper is used, so
 * the returned node has no category/outputs and empty param maps.
 *
 * Side effect: when `node.attr` is null/undefined it is replaced with `{}` on
 * the passed-in `node` object.
 *
 * @param node The TensorFlow node definition to convert.
 * @returns A new `Node` whose `inputs` and `children` are empty arrays; they
 *     are not populated by this method.
 * @throws Error when the mapper declares an attribute whose `type` is not one
 *     of the types handled below.
 */
private mapNode(node: tensorflow.INodeDef): Node {
  // Unsupported ops will cause an error at run-time (not parse time), since
  // they may not be used by the actual execution subgraph.
  const mapper =
      getRegisteredOp(node.op) || this.opMappers[node.op] || {} as OpMapper;
  if (node.attr == null) {
    node.attr = {};
  }

  const newNode: Node = {
    name: node.name,
    op: node.op,
    category: mapper.category,
    // A leading '^' on an input name is stripped; other names are kept as-is.
    inputNames:
        (node.input ||
         []).map(input => input.startsWith('^') ? input.slice(1) : input),
    inputs: [],
    children: [],
    inputParams: {},
    attrParams: {},
    rawAttrs: node.attr,
    outputs: mapper.outputs
  };

  // Index every declared input param by name, recording its type and the
  // start/end input indices declared by the mapper.
  if (mapper.inputs != null) {
    newNode.inputParams =
        mapper.inputs.reduce<{[key: string]: InputParamValue}>(
            (map, param) => {
              map[param.name] = {
                type: param.type,
                inputIndexStart: param.start,
                inputIndexEnd: param.end
              };
              return map;
            },
            {});
  }

  // Resolve every declared attribute param from the node's raw attributes.
  if (mapper.attrs != null) {
    newNode.attrParams =
        mapper.attrs.reduce<{[key: string]: ParamValue}>((map, param) => {
          const type = param.type;

          // Runs `read` with the attribute's current TF name; if that yields
          // `undefined` and the mapper declares a deprecated TF name, runs
          // `read` again with the deprecated name instead.
          const readWithFallback = <T>(read: (tfName: string) => T): T => {
            const primary = read(param.tfName);
            if (primary === undefined && !!param.tfDeprecatedName) {
              return read(param.tfDeprecatedName);
            }
            return primary;
          };

          let value = undefined;
          switch (param.type) {
            case 'string':
              value = readWithFallback(
                  tfName => getStringParam(
                      node.attr, tfName, param.defaultValue as string));
              break;
            case 'string[]':
              value = readWithFallback(
                  tfName => getStringArrayParam(
                      node.attr, tfName, param.defaultValue as string[]));
              break;
            case 'number':
              // A missing (or otherwise falsy) default is normalised to 0.
              // The same normalised default is used for the deprecated-name
              // lookup so both lookups agree.
              value = readWithFallback(
                  tfName => getNumberParam(
                      node.attr, tfName,
                      (param.defaultValue || 0) as number));
              break;
            case 'number[]':
              value = readWithFallback(
                  tfName => getNumericArrayParam(
                      node.attr, tfName, param.defaultValue as number[]));
              break;
            case 'bool':
              value = readWithFallback(
                  tfName => getBoolParam(
                      node.attr, tfName, param.defaultValue as boolean));
              break;
            case 'bool[]':
              value = readWithFallback(
                  tfName => getBoolArrayParam(
                      node.attr, tfName, param.defaultValue as boolean[]));
              break;
            case 'shape':
              value = readWithFallback(
                  tfName => getTensorShapeParam(
                      node.attr, tfName, param.defaultValue as number[]));
              break;
            case 'shape[]':
              value = readWithFallback(
                  tfName => getTensorShapeArrayParam(
                      node.attr, tfName, param.defaultValue as number[][]));
              break;
            case 'dtype':
              value = readWithFallback(
                  tfName => getDtypeParam(
                      node.attr, tfName, param.defaultValue as DataType));
              break;
            case 'dtype[]':
              value = readWithFallback(
                  tfName => getDtypeArrayParam(
                      node.attr, tfName, param.defaultValue as DataType[]));
              break;
            case 'func':
              value = readWithFallback(
                  tfName => getFuncParam(
                      node.attr, tfName, param.defaultValue as string));
              break;
            case 'tensor':
            case 'tensors':
              // Nothing is read from the raw attributes for these types; the
              // entry is stored with an undefined value.
              break;
            default:
              throw new Error(
                  `Unsupported param type: ${param.type} for op: ${node.op}`);
          }
          map[param.name] = {value, type};
          return map;
        }, {});
  }
  return newNode;
}