def apply_scaling()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/models/llama_model_exportable_hf.py [0:0]


def apply_scaling(freqs: torch.Tensor, config: RopeScalingArgs):
    """Rescale rotary-embedding frequencies according to ``config``.

    Every frequency is classified by its wavelength ``2 * pi / freq``:

    * wavelength shorter than
      ``original_max_position_embeddings / high_freq_factor``: the frequency
      is returned unchanged;
    * wavelength longer than
      ``original_max_position_embeddings / low_freq_factor``: the frequency
      is divided by ``factor``;
    * anything in between: a linear blend of the scaled and the unscaled
      frequency, weighted by where the wavelength sits between the two
      thresholds.

    The computation is fully elementwise, so ``freqs`` may have any shape.

    Args:
        freqs: Tensor of frequencies (floating-point dtype expected).
        config: Object exposing ``factor``, ``low_freq_factor``,
            ``high_freq_factor`` and ``original_max_position_embeddings``.

    Returns:
        A new tensor with the same shape and device as ``freqs`` (and the
        same dtype for floating-point input).

    Raises:
        ValueError: If both wavelength thresholds coincide while at least one
            frequency would need interpolation (the blend weight would
            require a division by zero).
    """
    scale_factor = config.factor
    low_freq_factor = config.low_freq_factor
    high_freq_factor = config.high_freq_factor
    old_context_len = config.original_max_position_embeddings

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    is_high_freq = wavelen < high_freq_wavelen
    is_low_freq = wavelen > low_freq_wavelen
    scaled = freqs / scale_factor

    if low_freq_wavelen == high_freq_wavelen:
        # Degenerate configuration: there is no interpolation band. Only fail
        # when an element actually needs one (e.g. it sits exactly on the
        # shared threshold or is NaN); otherwise skip the blend entirely so we
        # never divide by zero.
        if bool((~(is_high_freq | is_low_freq)).any()):
            raise ValueError(
                "low_freq_factor and high_freq_factor yield the same "
                "wavelength threshold; cannot interpolate frequencies "
                "that fall on it"
            )
        return torch.where(is_high_freq, freqs, scaled)

    # Blend weight: 0 at the low-frequency threshold, 1 at the high-frequency
    # threshold.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scale_factor + smooth * freqs

    # Nested ``where`` keeps the precedence of the original if/elif/else:
    # the high-frequency test wins, then the low-frequency test, then blend.
    return torch.where(
        is_high_freq, freqs, torch.where(is_low_freq, scaled, blended)
    )