private[explainers] def generateSampleSizes()

in core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/KernelSHAPSampler.scala [44:127]


  /**
   * Splits a budget of `numSamples` coalition samples across coalition (subset) sizes.
   *
   * Subset sizes `k = 1 .. m / 2` are visited in increasing order. Sizes `k <= (m - 1) / 2` are
   * "paired" with their complement size `m - k` and share one slot whose base weight is doubled;
   * when `m` is even the middle size `m / 2` is unpaired. The base weight of size `k` is
   * `(m - 1) / (k * (m - k))`.
   *
   *  - Phase 1 (`recurse`): while the proportional share of the remaining budget for size `k`
   *    covers every coalition of that size (`comb(m, k)`, doubled when paired), the size is
   *    fully enumerated and its entries carry `kernelWeightFunc(k)`. The phase stops at the
   *    first size that cannot be filled.
   *  - Phase 2 (`allocateRemainingSamples`): whatever budget is left is spread proportionally
   *    over the remaining sizes; those entries carry weight `1.0`.
   *
   * @param m                number of features; must be positive
   * @param numSamples       number of coalitions to allocate; must be in `(0, 2^m - 2]`
   * @param kernelWeightFunc weight attached to a fully enumerated subset size `k`
   * @return `(sampleCount, weight)` entries. A paired size contributes two consecutive entries
   *         (presumably size `k` followed by size `m - k` — TODO confirm against the caller);
   *         the unpaired middle size contributes one. The sample counts sum to `numSamples`.
   */
  private[explainers] def generateSampleSizes(m: Int, numSamples: Int)
                                             (kernelWeightFunc: Int => Double): Array[(Int, Double)] = {
    assert(numSamples <= math.pow(2, m) - 2)
    assert(numSamples > 0)
    assert(m > 0)

    // Sizes 1..numSubsets are visited; sizes 1..numPairedSubset also stand for size m - k.
    val (numSubsets, numPairedSubset) = (m / 2, (m - 1) / 2)
    // Base weight (m - 1) / (k * (m - k)) for every visited subset size k.
    val weightsSeq = (1 to numSubsets) map (i => (m - 1).toDouble / (i * (m - i)))
    val weights = new BDV[Double](weightsSeq.toArray)
    // A paired slot covers both size k and size m - k, so it receives twice the weight.
    weights(0 until numPairedSubset) *= 2.0

    /** Phase 1: fully enumerate subset sizes while their proportional budget share covers them. */
    @tailrec
    def recurse(k: Int, samplesLeft: Int, acc: Array[(Int, Double)]): Array[(Int, Double)] = {
      assert(samplesLeft > 0)
      if (k > numSubsets) {
        acc
      } else {
        val kernelWeight = kernelWeightFunc(k)
        // Share of the remaining budget for size k, relative to all sizes not yet visited
        // (assumes `rescale` normalizes the slice to sum to 1 — confirm).
        val rescaledWeights = rescale(weights(k - 1 to -1))
        val paired = k <= numPairedSubset
        val combo = if (paired) comb(m, k).toLong * 2 else comb(m, k).toLong
        val allocation = rescaledWeights(0) * samplesLeft
        val subsetSizes = {
          if (allocation >= combo) {
            // subset of 'comb' size is filled up.
            if (paired) {
              (combo / 2).toInt :: (combo / 2).toInt :: Nil
            } else {
              combo.toInt :: Nil
            }
          } else {
            Nil
          }
        }

        if (subsetSizes.isEmpty) {
          // First size that cannot be fully enumerated: hand the rest over to phase 2.
          acc
        } else {
          val newSamplesLeft = samplesLeft - sum(subsetSizes)
          if (newSamplesLeft > 0) {
            recurse(k + 1, newSamplesLeft, acc ++ subsetSizes.map(s => (s, kernelWeight)))
          } else {
            acc ++ subsetSizes.map(s => (s, kernelWeight))
          }
        }
      }
    }

    /**
     * Phase 2: spread `samplesLeft` (> 0) over subset sizes `k .. numSubsets`.
     * Never allocates more than `samplesLeft`, and the last size absorbs whatever is left,
     * so the returned counts always sum to exactly the budget passed in.
     */
    @tailrec
    def allocateRemainingSamples(k: Int, samplesLeft: Int, acc: Array[Int]): Array[Int] = {
      val paired = k <= numPairedSubset
      val subsetSizes = {
        if (k >= numSubsets) {
          // Last subset size: assign everything that remains so the total is exact.
          if (paired) {
            val first = (samplesLeft + 1) / 2
            first :: (samplesLeft - first) :: Nil
          } else {
            samplesLeft :: Nil
          }
        } else {
          // Any size before the last one is paired (the unpaired middle size is always last).
          val rescaledWeights = rescale(weights(k - 1 to -1))
          val allocation = rescaledWeights(0) * samplesLeft
          if (allocation >= 1) {
            // Rounding the half up must not allocate more samples than remain.
            val half = math.ceil(allocation / 2).toInt
            val first = math.min(half, samplesLeft)
            first :: math.min(half, samplesLeft - first) :: Nil
          } else {
            // Share rounds below one sample: still give size k a single sample.
            1 :: 0 :: Nil
          }
        }
      }

      val newSamplesLeft = samplesLeft - sum(subsetSizes)
      if (newSamplesLeft > 0) {
        allocateRemainingSamples(k + 1, newSamplesLeft, acc ++ subsetSizes)
      } else {
        acc ++ subsetSizes
      }
    }

    val result = recurse(1, numSamples, Array.empty)

    val remainingSamples = numSamples - sum(result.map(_._1))

    if (remainingSamples == 0) {
      result
    } else {
      // Phase 1 only fills paired slots when budget remains (two entries each), so the next
      // unfilled size is result.length / 2 + 1.
      result ++ allocateRemainingSamples(result.length / 2 + 1, remainingSamples, Array.empty).map(s => (s, 1.0))
    }
  }