Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift (28 lines of code) (raw):

import Foundation

/// Top-P logits warper.
///
/// Ranks the candidates by descending softmax probability and keeps the shortest
/// leading run whose cumulative probability exceeds `p`.
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
public struct TopPLogitsWarper: LogitsWarper {
    /// Cumulative-probability threshold that the kept candidates must exceed.
    public var p: Float

    /// Creates a warper with the given cumulative-probability threshold.
    /// - Parameter p: The threshold compared against the running probability sum.
    public init(p: Float) {
        self.p = p
    }

    /// Filters `logits` (and the matching `indices`) down to the top-p set.
    ///
    /// - Parameters:
    ///   - indices: Identifiers associated with each logit; looked up by the logit's
    ///     position, so it must cover every position taken from `logits`.
    ///   - logits: Raw scores to filter.
    /// - Returns: The kept identifiers and their logits, ordered by descending
    ///   probability. Both are empty when `logits` is empty. When no cumulative sum
    ///   exceeds `p`, every candidate is kept.
    public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
        if logits.isEmpty {
            return (indices: [], logits: [])
        }

        let probabilities = Math.softmax(logits)

        // Pair each logit with its original position and probability, most probable first.
        let ranked: [(index: Int, logit: Float, prob: Float)] = zip(logits, probabilities)
            .enumerated()
            .map { (index: $0.offset, logit: $0.element.0, prob: $0.element.1) }
            .sorted { $0.prob > $1.prob }

        let runningTotals = Math.cumsum(ranked.map(\.prob))

        // Position of the first running total above `p`; fall back to the last position
        // so that everything is kept when the threshold is never crossed.
        let firstAbove = runningTotals.enumerated().first { $0.element > p }
        let cutoff = firstAbove?.offset ?? (runningTotals.count - 1)

        // The cutoff element itself is included (closed range).
        let kept = ranked[0...cutoff]
        return (
            indices: kept.map { indices[$0.index] },
            logits: kept.map(\.logit)
        )
    }
}