in cuda/lltm_cuda.cpp [47:76]
// Host-side entry point for the LLTM backward pass.
//
// Validates every tensor argument with CHECK_INPUT and then forwards all of
// them, unchanged and in the same order, to lltm_cuda_backward. It returns
// whatever that function returns.
//
// NOTE(review): CHECK_INPUT is defined outside this view. It presumably
// asserts that the tensor is a CUDA tensor and is contiguous — confirm
// against its definition.
// NOTE(review): the meaning and order of the returned gradient tensors are
// determined by lltm_cuda_backward, which is not visible here — confirm
// there.
std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  // Validate every tensor that is handed to the CUDA implementation, in
  // parameter order. new_cell was previously skipped, even though it is
  // forwarded exactly like the others.
  CHECK_INPUT(grad_h);
  CHECK_INPUT(grad_cell);
  CHECK_INPUT(new_cell);
  CHECK_INPUT(input_gate);
  CHECK_INPUT(output_gate);
  CHECK_INPUT(candidate_cell);
  CHECK_INPUT(X);
  CHECK_INPUT(gate_weights);
  CHECK_INPUT(weights);

  return lltm_cuda_backward(
      grad_h,
      grad_cell,
      new_cell,
      input_gate,
      output_gate,
      candidate_cell,
      X,
      gate_weights,
      weights);
}