# From tt_embeddings_ops.py
def backward(ctx, d_output: torch.Tensor) -> Tuple[torch.Tensor]:
    """Backward pass for the TT-embedding autograd Function.

    Two regimes, selected by ``ctx.sparse``:

    * Sparse: the fused SGD / row-wise-Adagrad kernels apply the optimizer
      update to ``ctx.tt_cores`` (and the cache weights) in place, so this
      function returns ``None`` for every forward input — gradients are
      never materialized for the autograd engine.
    * Dense: ``tt_dense_backward`` returns explicit gradients for the TT
      cores (and ``cache_backward_dense`` for the cache weight), which are
      placed in the returned gradient tuple.

    The returned tuple must have one entry per argument of ``forward``
    (the 19 leading ``None``/grad slots plus one slot per TT core) —
    NOTE(review): this layout is inferred from the slot comments below;
    confirm against the forward signature.

    Args:
        ctx: autograd context carrying saved tensors plus scalar attrs
            (D, tt_p_shapes, tt_q_shapes, tt_ranks, nnz_tt, nnz_cached,
            optimizer, learning_rate, eps, sparse, tt_cores, ...).
        d_output: gradient of the loss w.r.t. the forward output.
            Presumably shaped (batch rows, D) — TODO confirm.
    """
    (
        L,
        indices,
        rowidx,
        tableidx,
        cache_locations,
        cache_optimizer_state,
        cache_weight,
    ) = ctx.saved_tensors
    # Work-partitioning constant handed to every backward kernel;
    # presumably rows processed per kernel batch — TODO confirm.
    batch_count = 1000
    if ctx.sparse:
        # Fused backward+optimizer path: kernels update parameters in place.
        if ctx.optimizer in [OptimType.SGD, OptimType.EXACT_SGD]:
            # pyre-fixme[16]
            tt_embeddings.tt_sgd_backward(
                batch_count,
                ctx.D,
                ctx.learning_rate,
                ctx.tt_p_shapes,
                ctx.tt_q_shapes,
                ctx.tt_ranks,
                L,
                ctx.nnz_tt,
                indices,
                rowidx,
                tableidx,
                d_output,
                list(ctx.tt_cores),
            )
            if ctx.nnz_cached > 0:
                # Rows beyond nnz_tt were served from the cache; update the
                # cached weights directly with SGD.
                # pyre-fixme[16]
                tt_embeddings.cache_backward_sgd(
                    ctx.nnz_cached,
                    d_output,
                    cache_locations[ctx.nnz_tt :],
                    rowidx[ctx.nnz_tt :],
                    ctx.learning_rate,
                    cache_weight,
                )
        else:
            # All non-SGD optimizers fall through to (row-wise) Adagrad,
            # which also consumes/updates ctx.optimizer_state in place.
            # pyre-fixme[16]
            tt_embeddings.tt_adagrad_backward(
                batch_count,
                ctx.D,
                ctx.learning_rate,
                ctx.eps,
                ctx.tt_p_shapes,
                ctx.tt_q_shapes,
                ctx.tt_ranks,
                L,
                ctx.nnz_tt,
                indices,
                rowidx,
                tableidx,
                d_output,
                ctx.optimizer_state,
                list(ctx.tt_cores),
            )
            if ctx.nnz_cached > 0:
                # Cache-resident rows: approximate row-wise Adagrad update
                # applied in place to cache_weight / cache_optimizer_state.
                # pyre-fixme[16]
                tt_embeddings.cache_backward_rowwise_adagrad_approx(
                    ctx.nnz_cached,
                    d_output,
                    cache_locations[ctx.nnz_tt :],
                    rowidx[ctx.nnz_tt :],
                    ctx.learning_rate,
                    ctx.eps,
                    cache_optimizer_state,
                    cache_weight,
                )
        # Parameters were already updated in place above, so every slot in
        # the gradient tuple is None (including one per TT core).
        # pyre-fixme[7]
        return tuple(
            [
                None,  # D
                None,  # tt_p_shapes
                None,  # tt_q_shapes
                None,  # tt_ranks
                None,  # K
                None,  # nnz_tt
                None,  # nnz_cached
                None,  # indices
                None,  # offsets
                None,  # rowidx
                None,  # tableidx
                None,  # optimizer
                None,  # learning_rate
                None,  # eps
                None,  # sparse
                None,  # cache_locations
                None,  # cache_optimizer_state
                None,  # cache_weight
                None,  # optimizer_state
            ]
            + [None] * len(ctx.tt_cores)
        )
    else:
        # Dense path: materialize gradients and let the optimizer step
        # happen elsewhere. d_tt_cores holds one gradient per TT core.
        # pyre-fixme[16]
        d_tt_cores = tt_embeddings.tt_dense_backward(
            batch_count,
            ctx.D,
            ctx.tt_p_shapes,
            ctx.tt_q_shapes,
            ctx.tt_ranks,
            L,
            ctx.nnz_tt,
            indices,
            rowidx,
            tableidx,
            d_output,
            list(ctx.tt_cores),
        )
        if ctx.nnz_cached > 0:
            # NOTE(review): learning_rate is passed to a *dense* (gradient-
            # returning) kernel — looks like a kernel-signature quirk rather
            # than an actual update; confirm against the extension.
            # pyre-fixme[16]
            d_cache_weight = tt_embeddings.cache_backward_dense(
                ctx.nnz_cached,
                d_output,
                cache_locations[ctx.nnz_tt :],
                rowidx[ctx.nnz_tt :],
                ctx.learning_rate,
                cache_weight,
            )
        else:
            d_cache_weight = None
        # Same slot layout as the sparse branch, but with real gradients in
        # the cache_weight slot and one per TT core appended.
        # pyre-fixme[7]
        return tuple(
            [
                None,  # D
                None,  # tt_p_shapes
                None,  # tt_q_shapes
                None,  # tt_ranks
                None,  # K
                None,  # nnz_tt
                None,  # nnz_cached
                None,  # indices
                None,  # offsets
                None,  # rowidx
                None,  # tableidx
                None,  # optimizer
                None,  # learning_rate
                None,  # eps
                None,  # sparse
                None,  # cache_locations
                None,  # cache_optimizer_state
                d_cache_weight,  # cache_weight
                None,  # optimizer_state
            ]
            + d_tt_cores
        )