def backward()

in tt_embeddings_ops.py [0:0]


    def backward(ctx, d_output: torch.Tensor) -> Tuple[torch.Tensor]:
        """Backward pass for the TT-embedding autograd function.

        Two regimes, selected by ``ctx.sparse``:

        * sparse: the fused C++ kernels consume ``d_output`` and the saved
          tensors directly (their return values are unused, so the optimizer
          update — SGD or the Adagrad variant, per ``ctx.optimizer`` — is
          presumably applied in place to the TT cores / cache weight), and
          every returned gradient is ``None``;
        * dense: explicit gradients are computed for the TT cores, and for
          the cache weight when any rows were served from the cache, and
          returned in the positions matching ``forward``'s inputs.
        """
        (
            L,
            indices,
            rowidx,
            tableidx,
            cache_locations,
            cache_optimizer_state,
            cache_weight,
        ) = ctx.saved_tensors
        batch_count = 1000
        # Cache-served rows occupy the tail [nnz_tt:] of these saved tensors.
        cached = slice(ctx.nnz_tt, None)
        # forward() takes 19 non-TT-core inputs; their gradients are always
        # None except (in dense mode) cache_weight, which sits at position 17.
        n_fixed_inputs = 19
        cache_weight_pos = 17

        if not ctx.sparse:
            # Dense path: produce real gradients instead of in-place updates.
            # pyre-fixme[16]
            d_tt_cores = tt_embeddings.tt_dense_backward(
                batch_count,
                ctx.D,
                ctx.tt_p_shapes,
                ctx.tt_q_shapes,
                ctx.tt_ranks,
                L,
                ctx.nnz_tt,
                indices,
                rowidx,
                tableidx,
                d_output,
                list(ctx.tt_cores),
            )
            d_cache_weight = None
            if ctx.nnz_cached > 0:
                # pyre-fixme[16]
                d_cache_weight = tt_embeddings.cache_backward_dense(
                    ctx.nnz_cached,
                    d_output,
                    cache_locations[cached],
                    rowidx[cached],
                    ctx.learning_rate,
                    cache_weight,
                )
            grads = [None] * n_fixed_inputs
            grads[cache_weight_pos] = d_cache_weight
            # pyre-fixme[7]
            return tuple(grads + d_tt_cores)

        # Sparse path: dispatch on the optimizer, return no gradients.
        if ctx.optimizer in (OptimType.SGD, OptimType.EXACT_SGD):
            # pyre-fixme[16]
            tt_embeddings.tt_sgd_backward(
                batch_count,
                ctx.D,
                ctx.learning_rate,
                ctx.tt_p_shapes,
                ctx.tt_q_shapes,
                ctx.tt_ranks,
                L,
                ctx.nnz_tt,
                indices,
                rowidx,
                tableidx,
                d_output,
                list(ctx.tt_cores),
            )
            if ctx.nnz_cached > 0:
                # pyre-fixme[16]
                tt_embeddings.cache_backward_sgd(
                    ctx.nnz_cached,
                    d_output,
                    cache_locations[cached],
                    rowidx[cached],
                    ctx.learning_rate,
                    cache_weight,
                )
        else:
            # Any non-SGD optimizer takes the Adagrad kernels.
            # pyre-fixme[16]
            tt_embeddings.tt_adagrad_backward(
                batch_count,
                ctx.D,
                ctx.learning_rate,
                ctx.eps,
                ctx.tt_p_shapes,
                ctx.tt_q_shapes,
                ctx.tt_ranks,
                L,
                ctx.nnz_tt,
                indices,
                rowidx,
                tableidx,
                d_output,
                ctx.optimizer_state,
                list(ctx.tt_cores),
            )
            if ctx.nnz_cached > 0:
                # pyre-fixme[16]
                tt_embeddings.cache_backward_rowwise_adagrad_approx(
                    ctx.nnz_cached,
                    d_output,
                    cache_locations[cached],
                    rowidx[cached],
                    ctx.learning_rate,
                    ctx.eps,
                    cache_optimizer_state,
                    cache_weight,
                )
        # One None per fixed forward() input plus one per TT core.
        # pyre-fixme[7]
        return tuple([None] * (n_fixed_inputs + len(ctx.tt_cores)))