def load_weights()

in arctic_inference/vllm/spec_dec/arctic_speculator.py [0:0]


    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        When ``self.method == "sum_lstm"`` and ``self.tie_lstm_embs`` is set,
        the checkpoint is first rewritten in place:

        * the ``input_emb.0.weight``, ``cell_emb.0.weight`` and
          ``output_emb.0.weight`` entries are discarded;
        * for every parameter of this module whose name contains ``"projs."``
          the matching ``forget_proj`` / ``input_proj`` / ``output_proj`` /
          ``cell_proj`` checkpoint tensors are removed and concatenated (in
          that order, along dim 0) into one tensor stored under the
          parameter's name.

        Every remaining tensor is then passed to ``self.maybe_load_weight``
        together with the parameter of the same name (with any
        ``"speculator."`` substring removed), or ``None`` when no such
        parameter exists. Tensors whose name starts with ``"head"`` are also
        offered to the parameter named with a leading ``"q"``
        (``head...`` -> ``qhead...``).

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: In the tied ``sum_lstm`` case, if one of the embedding
                or per-gate projection entries expected above is missing.
        """
        import logging

        logger = logging.getLogger(__name__)

        weights = collections.OrderedDict(weights)
        if self.method == "sum_lstm" and self.tie_lstm_embs:
            # NOTE(review): presumably these per-gate embeddings are redundant
            # when the embeddings are tied -- confirm against the training
            # checkpoint layout.
            for emb_key in ("input_emb.0.weight", "cell_emb.0.weight",
                            "output_emb.0.weight"):
                weights.pop(emb_key)

            # Order matters: it fixes the row layout of the fused tensor.
            gate_names = ("forget_proj", "input_proj", "output_proj",
                          "cell_proj")
            for name, _ in self.named_parameters():
                if "projs." not in name:
                    continue
                logger.debug("REPLACING %s", name)
                weights[name] = torch.cat([
                    weights.pop(name.replace("projs", gate))
                    for gate in gate_names
                ])

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights.items():
            logger.debug("LOADING %s", name)
            name = name.replace("speculator.", "")
            self.maybe_load_weight(params_dict.get(name), loaded_weight)

            if name.startswith("head"):
                # Rewrite only the leading "head"; a blanket str.replace would
                # also mangle later occurrences (e.g. "head.0.head_w") and
                # make the qhead lookup miss.
                self.maybe_load_weight(params_dict.get("q" + name),
                                       loaded_weight)