in speech-commands/src/training_utils.ts [93:164]
/**
 * Split examples into training and validation subsets, class by class.
 *
 * The example indices of every class are shuffled in place with
 * `tf.util.shuffle`. The first `Math.round(count * (1 - valSplit))` of them go
 * to the training subset and the rest go to the validation subset. Splitting
 * each class separately keeps the class proportions of both subsets close to
 * those of the input.
 *
 * @param xs Per-example features; either all `number[]` or all
 *   `Float32Array`. Elements are not copied: the returned subsets hold
 *   references to the same objects as `xs`.
 * @param ys Class index of each example. Must have the same length as `xs`
 *   and contain only non-negative integers. The indices need not be
 *   contiguous (e.g. labels `[0, 2]` with no example of class 1 are fine).
 * @param valSplit Fraction of each class to place in the validation subset;
 *   must satisfy `0 < valSplit < 1`.
 * @returns Training and validation features and labels. Within each subset,
 *   examples are grouped by ascending class index.
 * @throws Error if `valSplit` is out of range, if `xs` and `ys` differ in
 *   length, or if any class index is not a non-negative integer.
 */
export function balancedTrainValSplitNumArrays(
    xs: number[][]|Float32Array[], ys: number[], valSplit: number): {
  trainXs: number[][]|Float32Array[],
  trainYs: number[],
  valXs: number[][]|Float32Array[],
  valYs: number[]
} {
  tf.util.assert(
      valSplit > 0 && valSplit < 1,
      () => `validationSplit is expected to be >0 and <1, ` +
          `but got ${valSplit}`);
  tf.util.assert(
      xs.length === ys.length,
      () => `Expected xs and ys to have the same length, but got ` +
          `${xs.length} and ${ys.length}`);
  // Labels are used as array indices below; anything other than a
  // non-negative integer would become a plain property and be dropped.
  const badLabelIndex = ys.findIndex(
      classIndex => !Number.isInteger(classIndex) || classIndex < 0);
  tf.util.assert(
      badLabelIndex === -1,
      () => `Expected class indices to be non-negative integers, but got ` +
          `${ys[badLabelIndex]} at index ${badLabelIndex}`);

  // Group example indices by class. Class indices that never occur in `ys`
  // leave holes (undefined slots) in this array.
  const indicesByClass: number[][] = [];
  for (let exampleIndex = 0; exampleIndex < ys.length; ++exampleIndex) {
    const classIndex = ys[exampleIndex];
    if (indicesByClass[classIndex] == null) {
      indicesByClass[classIndex] = [];
    }
    indicesByClass[classIndex].push(exampleIndex);
  }

  const trainIndices: number[] = [];
  const valIndices: number[] = [];
  for (const classIndices of indicesByClass) {
    // Skip holes left by class indices absent from `ys`.
    if (classIndices == null) {
      continue;
    }
    // Shuffle in place so the train/val assignment within a class is random.
    tf.util.shuffle(classIndices);
    const cutoff = Math.round(classIndices.length * (1 - valSplit));
    for (let j = 0; j < classIndices.length; ++j) {
      (j < cutoff ? trainIndices : valIndices).push(classIndices[j]);
    }
  }

  // Picks `source[i]` for every `i` in `indices`, preserving order.
  function gather<T>(source: T[], indices: number[]): T[] {
    return indices.map(index => source[index]);
  }

  const trainYs = gather(ys, trainIndices);
  const valYs = gather(ys, valIndices);
  // `xs` is homogeneous per its type, so its first element decides which
  // array type the returned subsets use. An empty `xs` takes the
  // Float32Array branch and yields empty subsets.
  if (!Array.isArray(xs[0])) {
    const floatXs = xs as Float32Array[];
    return {
      trainXs: gather(floatXs, trainIndices),
      trainYs,
      valXs: gather(floatXs, valIndices),
      valYs
    };
  }
  const arrayXs = xs as number[][];
  return {
    trainXs: gather(arrayXs, trainIndices),
    trainYs,
    valXs: gather(arrayXs, valIndices),
    valYs
  };
}