optimum/habana/transformers/models/vits/modeling_vits.py (41 lines of code) (raw):

import numpy as np import torch from torch import nn from transformers.models.vits.modeling_vits import _rational_quadratic_spline from transformers.utils import logging logger = logging.get_logger(__name__) def gaudi_unconstrained_rational_quadratic_spline( inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=False, tail_bound=5.0, min_bin_width=1e-3, min_bin_height=1e-3, min_derivative=1e-3, ): """ Copied from _unconstrained_rational_quadratic_spline: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/vits/modeling_vits.py#L126 The only differences are: - WA to fix hpu graph accuracy issue """ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask outputs = torch.zeros_like(inputs) log_abs_det = torch.zeros_like(inputs) constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1)) unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., -1] = constant outputs[outside_interval_mask] = inputs[outside_interval_mask] log_abs_det[outside_interval_mask] = 0.0 outputs_i, log_abs_det_i = _rational_quadratic_spline( inputs=inputs[inside_interval_mask], unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], reverse=reverse, tail_bound=tail_bound, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, ) outputs = outputs_i * inside_interval_mask + outputs * outside_interval_mask log_abs_det = log_abs_det_i * inside_interval_mask + log_abs_det * outside_interval_mask return outputs, log_abs_det