Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift (18 lines of code) (raw):
import Foundation
/// `RepetitionPenaltyWarper` discourages the repetition of previous tokens by penalizing their logits.
///
/// For every distinct position listed in `indices`, a negative logit is multiplied by `penalty`
/// and a non-negative logit is divided by it. The penalty is applied at most once per token,
/// even when the same index appears several times in `indices`. This matches the reference implementation:
/// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294
public struct RepetitionPenaltyWarper: LogitsWarper {
    /// The penalty factor.
    ///
    /// Values greater than 1 lower the affected logits (discouraging repetition).
    /// A value of exactly 1 leaves them unchanged. Values between 0 and 1 raise them.
    /// Should be strictly positive: a zero penalty divides non-negative logits by zero.
    public var penalty: Float

    /// Creates a warper with the given penalty factor.
    /// - Parameter penalty: The penalty factor, stored as `Float`.
    public init(penalty: Double) {
        self.penalty = Float(penalty)
    }

    /// Penalizes the logits at the positions listed in `indices`.
    ///
    /// - Parameters:
    ///   - indices: Positions in `logits` to penalize. Each must be a valid index into `logits`.
    ///     Duplicate positions are penalized only once.
    ///   - logits: The logit values to adjust.
    /// - Returns: `indices` unchanged, together with a copy of `logits` in which each listed
    ///   position has been penalized exactly once.
    public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
        var logits = logits
        // Deduplicate first: a token that occurs several times must be penalized once,
        // not once per occurrence. Each position is updated independently, so the
        // (unordered) iteration order of the set does not affect the result.
        for index in Set(indices) {
            if logits[index] < 0 {
                logits[index] *= penalty
            } else {
                logits[index] /= penalty
            }
        }
        return (indices, logits)
    }
}