in dualpipe/dualpipe.py [0:0]
def _reset_states(self) -> None:
    """Put every piece of per-run bookkeeping back to its initial value.

    Clears ``WeightGradStore`` and then rebinds each tracking attribute on
    ``self`` to a brand-new empty container, a zeroed counter pair, or
    ``None``. Fresh objects are created on every call, so nothing is shared
    with the containers that were bound before the reset.

    NOTE(review): most attributes are two-element pairs; presumably one slot
    per pipeline direction/phase -- confirm against the code that indexes
    them.
    """
    # Drop anything still held by the weight-gradient store first.
    WeightGradStore.clear()

    # Pending point-to-point operations and tensors queued for release.
    self.comm_ops: List[dist.P2POp] = []
    self.to_free: List[torch.Tensor] = []

    # Loss bookkeeping. `labels` and `criterion` start out as None even
    # though their annotations are not Optional.
    self.labels: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = None
    self.criterion: Callable = None
    self.loss_chunks: List[torch.Tensor] = []

    # Activation and gradient buffers: each is a pair of independent,
    # initially empty lists.
    self.input_chunks: Tuple[
        List[List[torch.Tensor]], List[List[torch.Tensor]]
    ] = ([], [])
    self.output_chunks: Tuple[
        List[List[torch.Tensor]], List[List[torch.Tensor]]
    ] = ([], [])
    self.input_grad_chunks: Tuple[
        List[List[torch.Tensor]], List[List[torch.Tensor]]
    ] = ([], [])
    self.output_grad_chunks: Tuple[
        List[List[torch.Tensor]], List[List[torch.Tensor]]
    ] = ([], [])

    # Progress counters: one [0, 0] pair each for compute, send and
    # receive, in forward (f) and backward (b) flavours.
    self.current_f_chunk_id: List[int] = [0, 0]
    self.current_b_chunk_id: List[int] = [0, 0]
    self.current_send_f_chunk_id: List[int] = [0, 0]
    self.current_send_b_chunk_id: List[int] = [0, 0]
    self.current_recv_f_chunk_id: List[int] = [0, 0]
    self.current_recv_b_chunk_id: List[int] = [0, 0]