in tt_embeddings_ops.py [0:0]
def full_weight(self) -> torch.Tensor:
    """Return the output of ``tt_matrix_to_full`` for this module's TT data.

    The call forwards ``tt_p_shapes``, ``tt_q_shapes``, ``tt_ranks`` and
    ``tt_cores`` from ``self``, followed by the fixed list ``[1, 0, 2, 3]``.

    NOTE(review): judging by the names, the result is presumably the dense
    (un-factorized) weight rebuilt from the TT cores -- confirm against
    ``tt_matrix_to_full``.

    Raises:
        AssertionError: if ``self.num_tables`` is not exactly 1.
    """
    has_single_table = self.num_tables == 1
    assert has_single_table, "full_weight() only supported for num_tables == 1 for now"

    # Positional arguments, gathered in the order the helper receives them.
    # NOTE(review): the trailing [1, 0, 2, 3] looks like an axis ordering for
    # the cores -- its exact meaning is defined by tt_matrix_to_full; confirm.
    call_args = (
        self.tt_p_shapes,
        self.tt_q_shapes,
        self.tt_ranks,
        self.tt_cores,
        [1, 0, 2, 3],
    )
    return tt_matrix_to_full(*call_args)