# candle-pyo3/quant-llama.py
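# This excerpt relies on names defined elsewhere in the file/package; the
# imports below reflect that assumption (candle's Python bindings expose
# `utils.load_gguf`/`utils.load_ggml` and a `QuantizedLlama` model class).
import sys

import candle
from candle import utils
from candle.models.llama import QuantizedLlama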
def main():
    if len(sys.argv) < 2:
        raise ValueError("missing weight file argument")
    filename = sys.argv[1]
    print(f"reading model file {filename}")
    if filename.endswith("gguf"):
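        # GGUF bundles the quantized tensors together with metadata:
        # hyper-parameters plus the tokenizer vocabulary.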
        all_tensors, metadata = utils.load_gguf(filename)
        vocab = metadata["tokenizer.ggml.tokens"]
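        # The vocab uses SentencePiece conventions: "▁" marks a word
        # boundary (render it as a space) and "<0x0A>" is the byte-fallback
        # token for a newline.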
        for i, v in enumerate(vocab):
            vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ")
        hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
        print(hparams)
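        # The dict printed above is only for inspection; now build the
        # hparams that QuantizedLlama expects from the GGUF metadata keys.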
        hparams = {
            "n_vocab": len(vocab),
            "n_embd": metadata["llama.embedding_length"],
            "n_mult": 256,
            "n_head": metadata["llama.attention.head_count"],
            "n_head_kv": metadata["llama.attention.head_count_kv"],
            "n_layer": metadata["llama.block_count"],
            "n_rot": metadata["llama.rope.dimension_count"],
            "rope_freq": metadata.get("llama.rope.freq_base", 10000.0),
            "ftype": metadata["general.file_type"],
            "context_length": metadata["llama.context_length"],
        }
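        # GGUF tensor names differ from the original GGML ones; map them to
        # the names QuantizedLlama expects (gguf_rename is defined earlier
        # in this file).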
        all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
    else:
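        # Legacy GGML checkpoints already ship hparams and vocab, but not a
        # context length, so fall back to the default 2048.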
        all_tensors, hparams, vocab = utils.load_ggml(filename)
        hparams["context_length"] = 2048
    print(hparams)
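    # Build the model directly on top of the loaded tensors; the weights
    # are kept in their quantized format.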
    model = QuantizedLlama(hparams, all_tensors)
    print("model built, starting inference")
    tokens = [1]
    for token_idx in range(500):
        last_token = tokens[-1]
        lt = candle.tensor([last_token]).unsqueeze(0)
        logits = model.forward(lt, len(tokens))
        # Greedy sampling for now
        # pr = candle.nn.softmax(logits, -1)
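        # Argmax over the vocab dimension of the first (and only) batch row,
        # then read the winning token id back into Python.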
        m = logits.get(0).argmax_keepdim(-1)
        next_token = m.values()[0]
        print(vocab[next_token], end="", flush=True)
        tokens.append(next_token)
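

# Entry point so the example can be run as a script, e.g.:
#   python quant-llama.py model.gguf
if __name__ == "__main__":
    main()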