in core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/KernelSHAPSampler.scala [44:127]
/**
  * Decides how many samples to draw for each subset size, together with the weight attached to them.
  *
  * Subset sizes `k = 1 .. m / 2` are visited in increasing order. Sizes with `k <= (m - 1) / 2` are
  * treated as "paired" and contribute two entries to the result (presumably size `k` and its
  * complement of size `m - k` - confirm against the caller); an unpaired size contributes one entry.
  *
  * The work is done in two phases:
  *  1. `recurse` walks `k` upwards for as long as the share of the remaining budget assigned to `k`
  *     is at least the number of subsets of that size. Such sizes are enumerated in full and their
  *     entries carry `kernelWeightFunc(k)`.
  *  2. `allocateRemainingSamples` spreads whatever budget is left over the following sizes in
  *     proportion to their rescaled weights. These entries carry a weight of `1.0`.
  *
  * NOTE(review): in phase 2 a paired size receives `2 * ceil(allocation / 2)` samples, which can be
  * larger than the budget that was left, so the returned counts may add up to slightly more than
  * `numSamples` - confirm that callers tolerate this.
  *
  * @param m                number of elements subsets are drawn from; asserted to be positive.
  * @param numSamples       total sample budget; asserted to be positive and at most `2^m - 2`.
  * @param kernelWeightFunc maps a subset size `k` to the weight attached to fully enumerated subsets.
  * @return `(sampleCount, weight)` entries ordered by subset size, two per paired size.
  */
private[explainers] def generateSampleSizes(m: Int, numSamples: Int)
                                           (kernelWeightFunc: Int => Double): Array[(Int, Double)] = {
  assert(numSamples <= math.pow(2, m) - 2)
  assert(numSamples > 0)
  assert(m > 0)

  val (numSubsets, numPairedSubset) = (m / 2, (m - 1) / 2)

  // The denominator is multiplied in Double on purpose: as an Int product, i * (m - i) wraps around
  // for m >= 92682 and would yield negative or infinite weights. For smaller m the value is
  // unchanged because both forms are exact.
  val weightsSeq = (1 to numSubsets) map (i => (m - 1).toDouble / (i.toDouble * (m - i)))
  val weights = new BDV[Double](weightsSeq.toArray)
  // Paired sizes stand for two subset sizes at once, so their weight is doubled (in place).
  weights(0 until numPairedSubset) *= 2.0

  // Phase 1: collect every subset size, starting at k, that can be enumerated completely.
  @tailrec
  def recurse(k: Int, samplesLeft: Int, acc: Array[(Int, Double)]): Array[(Int, Double)] = {
    assert(samplesLeft > 0)
    if (k > numSubsets) {
      acc
    } else {
      val kernelWeight = kernelWeightFunc(k)
      // Weights of sizes k .. numSubsets only (assuming BDV is Breeze's DenseVector, where -1
      // addresses the last index). `rescale` is defined elsewhere; presumably it normalises the
      // slice so element 0 is the share of size k - TODO confirm.
      val rescaledWeights = rescale(weights(k - 1 to -1))
      val paired = k <= numPairedSubset
      // Number of distinct subsets covered by this step; doubled when two sizes are handled at once.
      val combo = if (paired) comb(m, k).toLong * 2 else comb(m, k).toLong
      val allocation = rescaledWeights(0) * samplesLeft
      val subsetSizes = {
        if (allocation >= combo) {
          // The budget share covers every subset of this size: enumerate them all.
          if (paired) {
            (combo / 2).toInt :: (combo / 2).toInt :: Nil
          } else {
            combo.toInt :: Nil
          }
        } else {
          Nil
        }
      }
      if (subsetSizes.isEmpty) {
        // First size that cannot be fully enumerated: stop, the rest is handled by phase 2.
        acc
      } else {
        val newSamplesLeft = samplesLeft - sum(subsetSizes)
        if (newSamplesLeft > 0) {
          recurse(k + 1, newSamplesLeft, acc ++ subsetSizes.map(s => (s, kernelWeight)))
        } else {
          acc ++ subsetSizes.map(s => (s, kernelWeight))
        }
      }
    }
  }

  // Phase 2: distribute the left-over budget over sizes k, k + 1, ... until it is used up.
  @tailrec
  def allocateRemainingSamples(k: Int, samplesLeft: Int, acc: Array[Int]): Array[Int] = {
    val rescaledWeights = rescale(weights(k - 1 to -1))
    val paired = k <= numPairedSubset
    val allocation = rescaledWeights(0) * samplesLeft
    val subsetSizes = {
      if (paired) {
        if (allocation >= 1) {
          // Split evenly between the two entries, rounding each half up.
          val half = math.ceil(allocation / 2).toInt
          half :: half :: Nil
        } else {
          // Less than one sample was assigned: still give a single sample to the first entry.
          1 :: 0 :: Nil
        }
      } else {
        allocation.toInt :: Nil
      }
    }
    val newSamplesLeft = samplesLeft - sum(subsetSizes)
    if (newSamplesLeft > 0) {
      allocateRemainingSamples(k + 1, newSamplesLeft, acc ++ subsetSizes)
    } else {
      acc ++ subsetSizes
    }
  }

  val result = recurse(1, numSamples, Array.empty)
  val remainingSamples = numSamples - sum(result.map(_._1))
  if (remainingSamples == 0) {
    result
  } else {
    // Phase 1 emitted two entries per size it completed, so result.length / 2 + 1 is the first
    // size that still needs samples.
    val sampled = allocateRemainingSamples(result.length / 2 + 1, remainingSamples, Array.empty)
    result ++ sampled.map(s => (s, 1.0))
  }
}