in tfjs-core/src/io/io_utils.ts [112:228]
/**
 * Decodes a flat `ArrayBuffer` into named tensors as described by `specs`.
 *
 * Weights are read back-to-back from the start of `buffer`, in the order in
 * which their specs are listed. How the bytes of one weight are interpreted
 * depends on its spec:
 *
 * - Quantized (`spec.quantization` present):
 *   - `'uint8'` / `'uint16'`: every stored integer `v` becomes
 *     `v * scale + min`; the result is additionally rounded when the weight
 *     dtype is `'int32'`.
 *   - `'float16'`: the 16-bit values are expanded by the decoder returned
 *     from `getFloat16Decoder()`. Only `'float32'` weights are accepted.
 * - `'string'`: every element is a length prefix of
 *   `NUM_BYTES_STRING_LENGTH` bytes (read through a `Uint32Array`) followed
 *   by that many payload bytes, which are kept as a raw `Uint8Array`.
 * - `'float32'`, `'int32'`, `'bool'`: the bytes are viewed directly as a
 *   `Float32Array`, `Int32Array` or `Uint8Array` respectively.
 * - `'complex64'`: the bytes are viewed as float32 values holding interleaved
 *   (real, imaginary) pairs, which are split and recombined with `complex()`.
 *
 * @param buffer A flat ArrayBuffer carrying the binary values of the weights,
 *     concatenated in the order of `specs`.
 * @param specs Specifications (name, dtype, shape and optional quantization
 *     metadata) of the weights stored in `buffer`.
 * @returns A map from weight name to the decoded tensor.
 * @throws Error if a uint8/uint16-quantized weight lacks `min` or `scale`, if
 *     float16 quantization is used for a non-float32 weight, if the
 *     quantization dtype is unknown, or if the weight dtype is unsupported.
 */
export function decodeWeights(
    buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap {
  const out: NamedTensorMap = {};
  // Created lazily on the first float16-quantized weight, then reused for all
  // later ones. The parentheses matter: it is the variable that may be
  // undefined, not the decoder's return value.
  let float16Decode: ((buffer: Uint16Array) => Float32Array)|undefined;
  // Byte position in `buffer` at which the next weight starts.
  let offset = 0;

  for (const spec of specs) {
    const name = spec.name;
    const dtype = spec.dtype;
    const shape = spec.shape;
    const size = sizeFromShape(shape);
    let values: TypedArray|string[]|Uint8Array[];

    if ('quantization' in spec) {
      const quantization = spec.quantization;
      // uint8 and uint16 share the same affine (scale + min) scheme.
      const isAffine =
          quantization.dtype === 'uint8' || quantization.dtype === 'uint16';

      // Validate the quantization metadata before touching the buffer.
      if (isAffine) {
        if (!('min' in quantization && 'scale' in quantization)) {
          throw new Error(
              `Weight ${spec.name} with quantization ${quantization.dtype} ` +
              `doesn't have corresponding metadata min and scale.`);
        }
      } else if (quantization.dtype === 'float16') {
        if (dtype !== 'float32') {
          throw new Error(
              `Weight ${spec.name} is quantized with ${quantization.dtype} ` +
              `which only supports weights of type float32 not ${dtype}.`);
        }
      } else {
        throw new Error(
            `Weight ${spec.name} has unknown ` +
            `quantization dtype ${quantization.dtype}. ` +
            `Supported quantization dtypes are: ` +
            `'uint8', 'uint16', and 'float16'.`);
      }

      // The stored element width is that of the quantized dtype, not of the
      // weight's final dtype.
      const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
      const quantizedByteLength = size * quantizationSizeFactor;
      const byteBuffer = buffer.slice(offset, offset + quantizedByteLength);
      const quantizedArray = (quantization.dtype === 'uint8') ?
          new Uint8Array(byteBuffer) :
          new Uint16Array(byteBuffer);

      if (dtype === 'float32') {
        if (isAffine) {
          const dequantized = new Float32Array(quantizedArray.length);
          for (let i = 0; i < quantizedArray.length; i++) {
            dequantized[i] =
                quantizedArray[i] * quantization.scale + quantization.min;
          }
          values = dequantized;
        } else {
          // The validation above leaves float16 as the only other option.
          if (float16Decode === undefined) {
            float16Decode = getFloat16Decoder();
          }
          values = float16Decode(quantizedArray as Uint16Array);
        }
      } else if (dtype === 'int32') {
        // float16 was rejected above for non-float32 weights, so the
        // quantization here is always uint8 or uint16.
        const dequantized = new Int32Array(quantizedArray.length);
        for (let i = 0; i < quantizedArray.length; i++) {
          dequantized[i] = Math.round(
              quantizedArray[i] * quantization.scale + quantization.min);
        }
        values = dequantized;
      } else {
        throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
      }
      offset += quantizedByteLength;
    } else if (dtype === 'string') {
      // Strings have variable length, so `offset` is advanced element by
      // element rather than once for the whole weight.
      const encodedStrings: Uint8Array[] = [];
      for (let i = 0; i < size; i++) {
        const byteLength = new Uint32Array(
            buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
        offset += NUM_BYTES_STRING_LENGTH;
        encodedStrings.push(
            new Uint8Array(buffer.slice(offset, offset + byteLength)));
        offset += byteLength;
      }
      values = encodedStrings;
    } else {
      const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
      const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);
      if (dtype === 'float32') {
        values = new Float32Array(byteBuffer);
      } else if (dtype === 'int32') {
        values = new Int32Array(byteBuffer);
      } else if (dtype === 'bool') {
        values = new Uint8Array(byteBuffer);
      } else if (dtype === 'complex64') {
        // `values` holds interleaved (real, imaginary) float32 pairs.
        values = new Float32Array(byteBuffer);
        const real = new Float32Array(values.length / 2);
        const image = new Float32Array(values.length / 2);
        for (let i = 0; i < real.length; i++) {
          real[i] = values[i * 2];
          image[i] = values[i * 2 + 1];
        }
        const realTensor = tensor(real, shape, 'float32');
        const imageTensor = tensor(image, shape, 'float32');
        out[name] = complex(realTensor, imageTensor);
        // The two halves are only intermediates of the complex tensor.
        realTensor.dispose();
        imageTensor.dispose();
      } else {
        throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
      }
      offset += size * dtypeFactor;
    }

    // complex64 has already stored its tensor in `out` above.
    if (dtype !== 'complex64') {
      out[name] = tensor(values, shape, dtype);
    }
  }
  return out;
}