in tt_embeddings_ops.py [0:0]
def forward(
ctx,
B: int,
D: int,
tt_p_shapes: List[int],
tt_q_shapes: List[int],
tt_ranks: List[int],
L: torch.Tensor,
nnz_tt: int,
nnz_cached: int,
indices: torch.Tensor,
rowidx: torch.Tensor,
tableidx: torch.Tensor,
optimizer: OptimType,
learning_rate: float,
eps: float,
sparse: bool,
cache_locations: torch.Tensor,
cache_optimizer_state: torch.Tensor,
cache_weight: torch.Tensor,
optimizer_state: List[torch.Tensor],
*tt_cores: Tuple[torch.Tensor],