in nl2sql/llms/vertexai.py [0:0]
def get_num_tokens(self, text: str) -> int:
    """
    Count the tokens in ``text`` using the PredictionService API.

    The text is wrapped as the ``content`` field of a single protobuf
    instance and sent to ``count_tokens`` through a client bound to the
    regional endpoint derived from ``self.location``.

    Args:
        text: The text whose tokens should be counted.

    Returns:
        The ``total_tokens`` value of the ``count_tokens`` response.
    """
    # The request instance is a protobuf Struct holding the raw text.
    content_struct = struct_pb2.Struct()
    content_struct.update({"content": text})

    # A new client is built on every call, pointed at the regional host.
    prediction_client = aiplatform_v1beta1.PredictionServiceClient(
        client_options={
            "api_endpoint": f"{self.location}-aiplatform.googleapis.com"
        }
    )

    # The endpoint name is read from a private attribute of the wrapped client.
    endpoint_name = self.client._endpoint_name  # pylint: disable = protected-access
    request_instances = [struct_pb2.Value(struct_value=content_struct)]

    count_response = prediction_client.count_tokens(
        endpoint=endpoint_name,
        instances=request_instances,
    )
    return count_response.total_tokens