export function dot()

in rfcs/20210731-tfjs-named-tensors/gtensor.ts [64:112]


/**
 * Contracts dimension `d1` of one GTensor against dimension `maybed2` of
 * another (a generalized dot product), via `tf.einsum`.
 *
 * Each input dimension is given a one-letter einsum label. The two
 * contracted dimensions share a label that is left out of the output, so
 * einsum multiplies along them and sums them away. Every other dimension is
 * kept, in order: first the remaining dimensions of `d1.gtensor`, then the
 * remaining dimensions of `maybed2.gtensor`.
 *
 * @param d1 Dimension (of the first GTensor) to contract over.
 * @param maybed2 Dimension (of the second GTensor) to contract over. Its
 *   `DotCompatibleDimension` type presumably encodes the compile-time
 *   compatibility check against `d1`; at runtime it is used as a plain
 *   `Dimension<G2, D2>`.
 * @returns A new GTensor holding the contraction, with both contracted
 *   dimensions removed from its dimension names.
 * @throws Error if the two tensors together need more distinct einsum labels
 *   than the 52 ASCII letters provide.
 */
export function dot<D1 extends G1, D2 extends G2, G1 extends DName, G2 extends DName>(
  d1: Dimension<G1, D1>,
  maybed2: DotCompatibleDimension<G1,D1,G2,D2>
): GTensor<Exclude<G1|G2, D1>> {
  // TODO: maybe we can make the type system do more for us... D extends D2 ? (D2
  //   extends D ? GTensor<Exclude<G1|G2, D>> : never) : never
  //
  // TODO: We use `tf.einsum`, and construct the inputs for it via strings;
  // this is quite a bit of indirection, and likely we could use an underlying
  // API directly and save the string construction and parsing.
  //
  // TODO: think about if the 'never' below is needed.
  const d2 = maybed2 as never as Dimension<G2, D2>;

  const d1Names = d1.gtensor.dimNames;
  const d2Names = d2.gtensor.dimNames;

  // einsum subscripts are conventionally (NumPy, TF) letters only. Bumping a
  // char code from 'A' would run past 'Z' into '[', '\\', ']', ..., so index
  // into an explicit alphabet instead.
  const einsumLabels =
    'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz';
  // Every dimension needs its own label, except the contracted pair, which
  // shares a single one.
  const labelsNeeded = d1Names.length + d2Names.length - 1;
  if (labelsNeeded > einsumLabels.length) {
    throw new Error(
      `dot: contracting a rank-${d1Names.length} tensor with a rank-` +
      `${d2Names.length} tensor needs ${labelsNeeded} einsum labels, but ` +
      `only ${einsumLabels.length} are available.`);
  }

  // d1's dimensions take the first labels, in order.
  const d1Labels = d1Names.map((_, i) => einsumLabels.charAt(i));
  const sharedLabel = einsumLabels.charAt(d1.index);

  // d2's dimensions take the next unused labels, except the contracted one,
  // which reuses d1's label so that einsum sums over it.
  const d2Labels: string[] = [];
  let nextFresh = d1Labels.length;
  for (let i = 0; i < d2Names.length; i++) {
    d2Labels.push(
      i === d2.index ? sharedLabel : einsumLabels.charAt(nextFresh++));
  }

  const resultLabels =
    d1Labels.concat(d2Labels).filter(label => label !== sharedLabel);
  const einsumStr =
    `${d1Labels.join('')},${d2Labels.join('')}->${resultLabels.join('')}`;

  const resultTensor = tf.einsum(
    einsumStr, d1.gtensor.tensor, d2.gtensor.tensor);

  // Same order as `resultLabels`: d1's remaining dims, then d2's.
  const newNames =
    (d1Names.filter((_, i) => i !== d1.index) as DName[])
    .concat(d2Names.filter((_, i) => i !== d2.index) as DName[]);
  return new GTensor(resultTensor, newNames as (Exclude<G1|G2, D1>)[]);
}