export function decodeWeights()

in tfjs-core/src/io/io_utils.ts [112:228]


/**
 * Decodes a flat `ArrayBuffer` into a map of named tensors.
 *
 * The weights are read back-to-back from `buffer`, in the order given by
 * `specs`; each spec determines how many bytes are consumed:
 *
 * - Quantized weights (`spec.quantization` is set) are stored as `uint8`,
 *   `uint16` or `float16` values. `uint8`/`uint16` values are dequantized as
 *   `value * scale + min` (rounded to the nearest integer for `int32`
 *   weights); `float16` values are expanded to `float32`.
 * - `string` weights store, for every element, a length prefix of
 *   `NUM_BYTES_STRING_LENGTH` bytes (read as a `Uint32Array` element, i.e. in
 *   platform byte order) followed by that many raw bytes.
 * - `float32`, `int32` and `bool` weights are stored as raw values.
 * - `complex64` weights are stored as interleaved (real, imaginary)
 *   `float32` pairs.
 *
 * @param buffer The binary data holding all weights, concatenated.
 * @param specs The name, shape, dtype and optional quantization metadata of
 *     each weight, in the order in which they appear in `buffer`.
 * @returns A map from weight name to the decoded tensor.
 * @throws Error if a quantized weight is missing its `min`/`scale` metadata,
 *     uses an unknown quantization dtype, or combines `float16` quantization
 *     with a dtype other than `float32`; and if a weight has a dtype that
 *     cannot be decoded.
 */
export function decodeWeights(
    buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap {
  const out: NamedTensorMap = {};
  // Created lazily on the first float16-quantized weight and then reused.
  // The parentheses matter: without them `| undefined` would bind to the
  // return type of the function instead of to the variable.
  let float16Decode: ((buffer: Uint16Array) => Float32Array) | undefined;
  let offset = 0;
  for (const spec of specs) {
    const name = spec.name;
    const dtype = spec.dtype;
    const shape = spec.shape;
    const size = sizeFromShape(shape);
    const quantization = spec.quantization;
    let values: TypedArray|string[]|Uint8Array[];

    // A `quantization` key that is present but null/undefined is treated like
    // an absent one, rather than crashing on the property accesses below.
    if (quantization != null) {
      // Validate the quantization metadata before touching the buffer.
      if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
        if (!('min' in quantization && 'scale' in quantization)) {
          throw new Error(
              `Weight ${spec.name} with quantization ${quantization.dtype} ` +
              `doesn't have corresponding metadata min and scale.`);
        }
      } else if (quantization.dtype === 'float16') {
        if (dtype !== 'float32') {
          throw new Error(
              `Weight ${spec.name} is quantized with ${quantization.dtype} ` +
              `which only supports weights of type float32 not ${dtype}.`);
        }
      } else {
        throw new Error(
            `Weight ${spec.name} has unknown ` +
            `quantization dtype ${quantization.dtype}. ` +
            `Supported quantization dtypes are: ` +
            `'uint8', 'uint16', and 'float16'.`);
      }

      const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
      const byteLength = size * quantizationSizeFactor;
      const byteBuffer = buffer.slice(offset, offset + byteLength);
      offset += byteLength;
      const quantizedArray = (quantization.dtype === 'uint8') ?
          new Uint8Array(byteBuffer) :
          new Uint16Array(byteBuffer);

      if (quantization.dtype === 'float16') {
        // The validation above guarantees that `dtype` is 'float32' here.
        if (float16Decode === undefined) {
          float16Decode = getFloat16Decoder();
        }
        values = float16Decode(quantizedArray as Uint16Array);
      } else if (dtype === 'float32') {
        // Affine dequantization of uint8/uint16 values.
        const dequantized = new Float32Array(quantizedArray.length);
        for (let i = 0; i < quantizedArray.length; i++) {
          dequantized[i] =
              quantizedArray[i] * quantization.scale + quantization.min;
        }
        values = dequantized;
      } else if (dtype === 'int32') {
        // Same affine mapping, rounded because the target dtype is integral.
        const dequantized = new Int32Array(quantizedArray.length);
        for (let i = 0; i < quantizedArray.length; i++) {
          dequantized[i] = Math.round(
              quantizedArray[i] * quantization.scale + quantization.min);
        }
        values = dequantized;
      } else {
        throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
      }
    } else if (dtype === 'string') {
      // Each element is a length prefix followed by that many raw bytes.
      const encodedStrings: Uint8Array[] = [];
      for (let i = 0; i < size; i++) {
        const byteLength = new Uint32Array(
            buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
        offset += NUM_BYTES_STRING_LENGTH;
        encodedStrings.push(
            new Uint8Array(buffer.slice(offset, offset + byteLength)));
        offset += byteLength;
      }
      values = encodedStrings;
    } else {
      const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
      const byteLength = size * dtypeFactor;
      const byteBuffer = buffer.slice(offset, offset + byteLength);
      offset += byteLength;

      if (dtype === 'float32') {
        values = new Float32Array(byteBuffer);
      } else if (dtype === 'int32') {
        values = new Int32Array(byteBuffer);
      } else if (dtype === 'bool') {
        values = new Uint8Array(byteBuffer);
      } else if (dtype === 'complex64') {
        // De-interleave the (real, imaginary) pairs into two float32 arrays.
        const interleaved = new Float32Array(byteBuffer);
        const real = new Float32Array(interleaved.length / 2);
        const imag = new Float32Array(interleaved.length / 2);
        for (let i = 0; i < real.length; i++) {
          real[i] = interleaved[i * 2];
          imag[i] = interleaved[i * 2 + 1];
        }
        const realTensor = tensor(real, shape, 'float32');
        const imagTensor = tensor(imag, shape, 'float32');
        out[name] = complex(realTensor, imagTensor);
        // The component tensors are only needed to build the complex tensor.
        realTensor.dispose();
        imagTensor.dispose();
        // The complex tensor is already stored; skip the generic path below.
        continue;
      } else {
        throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
      }
    }
    out[name] = tensor(values, shape, dtype);
  }
  return out;
}