in rfcs/20210731-tfjs-named-tensors/gtensor.ts [64:112]
/**
 * Contracts (sums the elementwise product over) dimension `d1` of one named
 * tensor with the dot-compatible dimension `maybed2` of another.
 *
 * The result keeps every other dimension: first the remaining dimensions of
 * `d1`'s tensor in their original order, then the remaining dimensions of
 * `maybed2`'s tensor in their original order.
 *
 * Implementation: every dimension of both tensors is given a one-character
 * label and the contraction is expressed as a `tf.einsum` equation. The
 * contracted dimension of the second tensor reuses the label of `d1`, and that
 * label is omitted from the output term, so einsum sums over it.
 *
 * @param d1 The dimension of the first tensor to contract over.
 * @param maybed2 The dimension of the second tensor to contract over; its type
 *     must be dot-compatible with `d1`.
 * @returns A new GTensor containing all dimensions of both inputs except the
 *     contracted one.
 */
export function dot<D1 extends G1, D2 extends G2, G1 extends DName, G2 extends DName>(
  d1: Dimension<G1, D1>,
  maybed2: DotCompatibleDimension<G1,D1,G2,D2>
): GTensor<Exclude<G1|G2, D1>> {
  // TODO: maybe we can make the type system do more for us... D extends D2 ?
  // (D2 extends D ? GTensor<Exclude<G1|G2, D>> : never) : never
  //
  // TODO: We use `tf.einsum`, and construct the inputs for it via strings;
  // this is quite a bit of indirection, and we could likely use an underlying
  // API directly and save the string construction and parsing.
  //
  // TODO: think about whether the 'never' below is needed.
  const d2 = maybed2 as never as Dimension<G2, D2>;

  const d1Names = d1.gtensor.dimNames;
  const d2Names = d2.gtensor.dimNames;

  // Labels are consecutive char codes starting at 'A': the first tensor's
  // dimensions take codes [0, d1Names.length), the second tensor's take the
  // next d2Names.length codes.
  const firstCharCode = 'A'.charCodeAt(0);
  const labelFor = (offset: number) => String.fromCharCode(firstCharCode + offset);

  const d1CharNames = d1Names.map((_, i) => labelFor(i));
  const d1CharName = d1CharNames[d1.index];
  const d2CharNames = d2Names.map((_, i) => labelFor(d1Names.length + i));
  // Give the second tensor's contracted axis the same label as d1's, so that
  // einsum treats them as one axis to sum over.
  d2CharNames.splice(d2.index, 1, d1CharName);

  const numLabels = d1CharNames.length + d2CharNames.length;
  if (numLabels > 52) {
    console.warn(
        `dot: building a tf.einsum equation with ${numLabels} dimension ` +
        `labels, which exceeds the supported limit of 52; the result may be ` +
        `incorrect.`);
  }

  // Output keeps every label except the contracted one, first tensor first.
  const resultCharNames =
      [...d1CharNames, ...d2CharNames].filter(c => c !== d1CharName);
  const einsumStr = `${d1CharNames.join('')},${d2CharNames.join('')}->${
      resultCharNames.join('')}`;
  const resultTensor = tf.einsum(
      einsumStr, d1.gtensor.tensor, d2.gtensor.tensor);

  // Dimension names in the same order as the einsum output term above.
  const newNames: DName[] = [
    ...d1Names.filter((_, i) => i !== d1.index),
    ...d2Names.filter((_, i) => i !== d2.index),
  ];
  return new GTensor(resultTensor, newNames as (Exclude<G1|G2, D1>)[]);
}