in arctic_inference/vllm/spec_dec/arctic_speculator.py [0:0]
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Copy checkpoint tensors into this module's parameters.

    Checkpoint names are first normalised by removing ``"speculator."``.
    When ``self.method == "sum_lstm"`` and ``self.tie_lstm_embs`` is set,
    the ``input_emb.0.weight``, ``cell_emb.0.weight`` and
    ``output_emb.0.weight`` entries are discarded, and for every parameter
    whose name contains ``"projs."`` the matching ``forget_proj``,
    ``input_proj``, ``output_proj`` and ``cell_proj`` tensors are
    concatenated (in that order, along dim 0) into one tensor stored under
    the parameter's own name.

    Every remaining tensor is then passed to ``self.maybe_load_weight``
    together with the same-named parameter, or ``None`` when no parameter
    has that name. Tensors whose name starts with ``"head"`` are also
    offered to the parameter named with ``"head"`` replaced by ``"qhead"``.

    Args:
        weights: ``(name, tensor)`` pairs read from the checkpoint.

    Raises:
        KeyError: If a gate projection needed to build a fused ``projs``
            parameter is missing from ``weights``.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Normalise names before any lookup so that the tied-embedding and
    # gate-projection keys below match whether or not the checkpoint
    # carries the "speculator." prefix.
    weights = collections.OrderedDict(
        (ckpt_name.replace("speculator.", ""), tensor)
        for ckpt_name, tensor in collections.OrderedDict(weights).items())
    params_dict = dict(self.named_parameters())

    if self.method == "sum_lstm" and self.tie_lstm_embs:
        # These entries are only thrown away, so a checkpoint that never
        # stored them is acceptable.
        for tied_name in ("input_emb.0.weight", "cell_emb.0.weight",
                          "output_emb.0.weight"):
            weights.pop(tied_name, None)

        for name in params_dict:
            if "projs." not in name:
                continue
            logger.debug("REPLACING %s", name)
            # The gate order fixes the row layout of the fused tensor and
            # must stay forget, input, output, cell.
            gate_weights = [
                weights.pop(name.replace("projs", f"{gate}_proj"))
                for gate in ("forget", "input", "output", "cell")
            ]
            weights[name] = torch.cat(gate_weights)

    for name, loaded_weight in weights.items():
        logger.debug("LOADING %s", name)
        self.maybe_load_weight(params_dict.get(name), loaded_weight)
        if name.startswith("head"):
            # The same tensor is offered to the "qhead" counterpart.
            qhead_param = params_dict.get(name.replace("head", "qhead"))
            self.maybe_load_weight(qhead_param, loaded_weight)