export function balancedTrainValSplitNumArrays()

in speech-commands/src/training_utils.ts [93:164]


/**
 * Splits a dataset into training and validation subsets while preserving the
 * per-class proportions (a stratified split).
 *
 * Example indices are grouped by their label in `ys`. Each group is shuffled
 * in place with `tf.util.shuffle`. The first
 * `Math.round(groupSize * (1 - valSplit))` indices of each group go to the
 * training subset and the remainder go to the validation subset. Groups are
 * visited in ascending label order.
 *
 * Labels do not need to be contiguous, non-negative or integral: every
 * distinct value in `ys` forms its own group, so no example is dropped.
 *
 * The returned `trainXs` / `valXs` hold the same element references as `xs`;
 * elements are not copied.
 *
 * @param xs Features, one entry per example. Either all `number[]` or all
 *   `Float32Array`; the kind is detected from `xs[0]`.
 * @param ys Class label of each example; must have the same length as `xs`.
 * @param valSplit Fraction of each class assigned to the validation subset;
 *   must satisfy `0 < valSplit < 1`.
 * @returns Training and validation features with their labels.
 * @throws Error (via `tf.util.assert`) if `valSplit` is outside `(0, 1)` or
 *   if `xs` and `ys` differ in length.
 */
export function balancedTrainValSplitNumArrays(
    xs: number[][]|Float32Array[], ys: number[], valSplit: number): {
  trainXs: number[][]|Float32Array[],
  trainYs: number[],
  valXs: number[][]|Float32Array[],
  valYs: number[]
} {
  tf.util.assert(
      valSplit > 0 && valSplit < 1,
      () => `validationSplit is expected to be >0 and <1, ` +
          `but got ${valSplit}`);
  tf.util.assert(
      xs.length === ys.length,
      () => `xs and ys are expected to have the same length, ` +
          `but got ${xs.length} and ${ys.length}`);
  // The element kind of the output arrays is decided from the first example.
  const isXsFloat32Array = !Array.isArray(xs[0]);

  // Group example indices by label. A Map is used instead of an array indexed
  // by label. With an array, a missing label leaves a hole that makes the
  // split loop dereference `undefined`. A negative or fractional label becomes
  // a non-index property that `.length` ignores, so its examples are lost.
  const indicesByLabel = new Map<number, number[]>();
  ys.forEach((label, exampleIndex) => {
    const group = indicesByLabel.get(label);
    if (group == null) {
      indicesByLabel.set(label, [exampleIndex]);
    } else {
      group.push(exampleIndex);
    }
  });

  // Ascending label order reproduces the array-based ordering for the usual
  // 0..numClasses-1 labels.
  const sortedGroups =
      Array.from(indicesByLabel.entries()).sort(([a], [b]) => a - b);

  const trainIndices: number[] = [];
  const valIndices: number[] = [];
  for (const [, group] of sortedGroups) {
    tf.util.shuffle(group);
    const cutoff = Math.round(group.length * (1 - valSplit));
    // Plain loop rather than push(...slice) to avoid argument-count limits
    // on very large classes.
    for (let j = 0; j < group.length; ++j) {
      if (j < cutoff) {
        trainIndices.push(group[j]);
      } else {
        valIndices.push(group[j]);
      }
    }
  }

  const trainYs = trainIndices.map(index => ys[index]);
  const valYs = valIndices.map(index => ys[index]);

  // The cast only narrows the union so that each output array is homogeneous,
  // matching the declared return type.
  if (isXsFloat32Array) {
    const typedXs = xs as Float32Array[];
    return {
      trainXs: trainIndices.map(index => typedXs[index]),
      trainYs,
      valXs: valIndices.map(index => typedXs[index]),
      valYs
    };
  }
  const plainXs = xs as number[][];
  return {
    trainXs: trainIndices.map(index => plainXs[index]),
    trainYs,
    valXs: valIndices.map(index => plainXs[index]),
    valYs
  };
}