mdr/retrieval/models/mhop_retriever.py [56:109]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if args.init_retriever != "":
            print(f"Load pretrained retriever from {args.init_retriever}")
            self.load_retriever(args.init_retriever)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize key encoder from query encoder
            param_k.requires_grad = False  # not updated by gradients, only by momentum

        self.k = args.k  # memory-bank (queue) size
        self.m = args.m  # momentum coefficient for the key-encoder update
        self.register_buffer("queue", torch.randn(self.k, config.hidden_size))
        # TODO: layer-normalize the queue embeddings?
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    def load_retriever(self, path):
        state_dict = torch.load(path)

        def strip_prefix(name):
            # drop the "module." prefix left by (Distributed)DataParallel checkpoints
            return name[len("module."):] if name.startswith("module.") else name

        state_dict = {strip_prefix(k): v for k, v in state_dict.items()
                      if strip_prefix(k) in self.encoder_q.state_dict()}
        self.encoder_q.load_state_dict(state_dict)

    @torch.no_grad()
    def momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def dequeue_and_enqueue(self, embeddings):
        """
        memory bank of previous context embeddings, c1 and c2
        """
        # gather keys before updating queue
        batch_size = embeddings.shape[0]
        ptr = int(self.queue_ptr)
        if ptr + batch_size > self.k:
            batch_size = self.k - ptr
            embeddings = embeddings[:batch_size]

        # if self.k % batch_size != 0:
        #     return
        # assert self.k % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[ptr:ptr + batch_size, :] = embeddings

        ptr = (ptr + batch_size) % self.k  # move pointer
        self.queue_ptr[0] = ptr


    def forward(self, batch):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
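
A usage note: the forward body is elided above, but the two @torch.no_grad() helpers are meant to run once per optimizer step, MoCo-style: momentum-update the key encoder, then push the step's context embeddings into the queue. Below is a minimal, hypothetical sketch of that wiring; model, batch, optimizer, loss_fn, and the "c_emb" output key are illustrative names, not taken from this repository.

import torch

def train_step(model, batch, optimizer, loss_fn):
    # Gradients flow only through encoder_q; encoder_k has requires_grad=False.
    optimizer.zero_grad()
    outputs = model(batch)  # forward body elided in the excerpt above
    loss = loss_fn(outputs)
    loss.backward()
    optimizer.step()

    # Let the key encoder track the query encoder via the EMA update,
    # then refresh the memory bank with the new context embeddings.
    model.momentum_update_key_encoder()
    model.dequeue_and_enqueue(outputs["c_emb"].detach())  # "c_emb" is a placeholder key
    return loss.item()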



mdr/retrieval/models/unified_retriever.py [125:171]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        if args.init_retriever != "":
            print(f"Load pretrained retriever from {args.init_retriever}")
            self.load_retriever(args.init_retriever)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize key encoder from query encoder
            param_k.requires_grad = False  # not updated by gradients, only by momentum

        self.k = args.k  # memory-bank (queue) size
        self.m = args.m  # momentum coefficient for the key-encoder update
        self.register_buffer("queue", torch.randn(self.k, config.hidden_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    def load_retriever(self, path):
        state_dict = torch.load(path)

        def strip_prefix(name):
            # drop the "module." prefix left by (Distributed)DataParallel checkpoints
            return name[len("module."):] if name.startswith("module.") else name

        state_dict = {strip_prefix(k): v for k, v in state_dict.items()
                      if strip_prefix(k) in self.encoder_q.state_dict()}
        self.encoder_q.load_state_dict(state_dict)

    @torch.no_grad()
    def momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def dequeue_and_enqueue(self, embeddings):
        """
        memory bank of previous context embeddings, c1 and c2
        """
        # gather keys before updating queue
        batch_size = embeddings.shape[0]
        ptr = int(self.queue_ptr)
        if ptr + batch_size > self.k:
            batch_size = self.k - ptr
            embeddings = embeddings[:batch_size]

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[ptr:ptr + batch_size, :] = embeddings
        ptr = (ptr + batch_size) % self.k  # move pointer
        self.queue_ptr[0] = ptr

    def forward(self, batch):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
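
To make the queue semantics concrete, here is a minimal standalone reproduction of the ring-buffer update, including the truncation branch at the wrap boundary. The sizes are illustrative, and plain tensors stand in for the registered module buffers.

import torch

k, hidden = 8, 4  # queue size and embedding dim (illustrative)
queue = torch.randn(k, hidden)
queue_ptr = torch.zeros(1, dtype=torch.long)

def dequeue_and_enqueue(embeddings):
    batch_size = embeddings.shape[0]
    ptr = int(queue_ptr)
    if ptr + batch_size > k:
        # rows past the end of the buffer are dropped, not wrapped around
        batch_size = k - ptr
        embeddings = embeddings[:batch_size]
    queue[ptr:ptr + batch_size, :] = embeddings
    queue_ptr[0] = (ptr + batch_size) % k  # move pointer

dequeue_and_enqueue(torch.randn(6, hidden))  # ptr: 0 -> 6
dequeue_and_enqueue(torch.randn(6, hidden))  # only 2 rows fit; the rest are dropped
assert int(queue_ptr) == 0  # pointer wrapped back to the front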



