def forward()

in tt_embeddings_ops.py [0:0]


    def forward(
        ctx,
        B: int,
        D: int,
        tt_p_shapes: List[int],
        tt_q_shapes: List[int],
        tt_ranks: List[int],
        L: torch.Tensor,
        nnz_tt: int,
        nnz_cached: int,
        indices: torch.Tensor,
        rowidx: torch.Tensor,
        tableidx: torch.Tensor,
        optimizer: OptimType,
        learning_rate: float,
        eps: float,
        sparse: bool,
        cache_locations: torch.Tensor,
        cache_optimizer_state: torch.Tensor,
        cache_weight: torch.Tensor,
        optimizer_state: List[torch.Tensor],
        *tt_cores: Tuple[torch.Tensor],