in fastchat/serve/model_worker.py
def get_embeddings(self, params):
    self.call_ct += 1

    try:
        tokenizer = self.tokenizer
        ret = {"embedding": [], "token_num": 0}

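        # Crude model-family detection from the class name; __process_embed_chunk
        # pools hidden states differently for each family.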
        model_type_dict = {
            "is_llama": "llama" in str(type(self.model)),
            "is_t5": "t5" in str(type(self.model)),
            "is_chatglm": "chatglm" in str(type(self.model)),
            "is_bert": "bert" in str(type(self.model)),
            "is_robert": "robert" in str(type(self.model)),
        }

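        # Tokenize the whole batch at once; with embed_in_truncate, inputs are
        # clipped to the model's context window instead of being chunked below.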
        if self.embed_in_truncate:
            encoding = tokenizer.batch_encode_plus(
                params["input"],
                padding=True,
                truncation="longest_first",
                return_tensors="pt",
                max_length=self.context_len,
            )
        else:
            encoding = tokenizer.batch_encode_plus(
                params["input"], padding=True, return_tensors="pt"
            )
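        # Move token ids to the worker's device; every non-pad position counts
        # as attended (assumes the tokenizer has a pad_token_id set).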
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = input_ids != tokenizer.pad_token_id

        base64_encode = params.get("encoding_format", None)

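        # Truncated path: one forward pass; __process_embed_chunk returns the
        # sum-pooled hidden states, which dividing by token_num turns into a mean.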
        if self.embed_in_truncate:
            chunk_embeddings, token_num = self.__process_embed_chunk(
                input_ids, attention_mask, **model_type_dict
            )
            embedding = chunk_embeddings / token_num
            normalized_embeddings = F.normalize(embedding, p=2, dim=1)
            ret["token_num"] = token_num
        else:
            all_embeddings = []
            all_token_num = 0
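            # Chunked path: slide a context_len-wide window over the sequence,
            # accumulating per-chunk embedding sums and token counts.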
            for i in range(0, input_ids.size(1), self.context_len):
                chunk_input_ids = input_ids[:, i : i + self.context_len]
                chunk_attention_mask = attention_mask[:, i : i + self.context_len]

                chunk_embeddings, token_num = self.__process_embed_chunk(
                    chunk_input_ids, chunk_attention_mask, **model_type_dict
                )
                all_embeddings.append(chunk_embeddings)
                all_token_num += token_num
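            # Average the accumulated sums over all tokens, then L2-normalize.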
            all_embeddings_tensor = torch.stack(all_embeddings)
            embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
            normalized_embeddings = F.normalize(embedding, p=2, dim=1)
            ret["token_num"] = all_token_num

        if base64_encode == "base64":
            out_embeddings = self.__encode_base64(normalized_embeddings)
        else:
            out_embeddings = normalized_embeddings.tolist()
        ret["embedding"] = out_embeddings

        gc.collect()
        torch.cuda.empty_cache()
        if self.device == "xpu":
            torch.xpu.empty_cache()
        if self.device == "npu":
            torch.npu.empty_cache()
    except torch.cuda.OutOfMemoryError as e:
        ret = {
            "text": f"{SERVER_ERROR_MSG}\n\n({e})",
            "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
        }
    except (ValueError, RuntimeError) as e:
        ret = {
            "text": f"{SERVER_ERROR_MSG}\n\n({e})",
            "error_code": ErrorCode.INTERNAL_ERROR,
        }
    return ret
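
For intuition, here is a minimal standalone sketch of the pooling and output
encoding performed above, using made-up tensors and a hypothetical
encode_base64 stand-in for __encode_base64 (not the FastChat API): summing
per-chunk sums and dividing by the total token count is a single mean-pool
over all tokens, and F.normalize makes each row unit-length.

import base64

import torch
import torch.nn.functional as F

# Pretend per-chunk sum-pooled outputs: two chunks, batch of 2, hidden size 4.
chunk_sums = [torch.full((2, 4), 3.0), torch.full((2, 4), 1.0)]
all_token_num = 3 + 1  # total tokens seen across both chunks

embedding = torch.stack(chunk_sums).sum(dim=0) / all_token_num  # mean pool
normalized = F.normalize(embedding, p=2, dim=1)  # unit L2 norm per row
assert torch.allclose(normalized.norm(dim=1), torch.ones(2))

def encode_base64(embeddings):
    # Hypothetical stand-in for __encode_base64: one base64 string per row
    # of raw float32 bytes.
    return [
        base64.b64encode(row.cpu().numpy().astype("float32").tobytes()).decode("utf-8")
        for row in embeddings
    ]

print(encode_base64(normalized))  # two short base64 strings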