# runinferenceutil/infra.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from datetime import datetime

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import PredictionResult, RunInference
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_RESPONSE_TOKENS = 256
model_name = "google/flan-t5-base"
# Shared tokenizer, loaded once at import time and reused by the helpers below.
tokenizer = AutoTokenizer.from_pretrained(model_name)

class FormatForBigquery(beam.DoFn):
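    """Formats the windowed detection percentages and the side-input prompt into a BigQuery row dict."""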
    def process(self, element, side, window=beam.DoFn.WindowParam):
        ts_format = '%Y-%m-%d %H:%M:%S.%f UTC'
        window_start = window.start.to_utc_datetime().strftime(ts_format)
        window_end = window.end.to_utc_datetime().strftime(ts_format)
        logging.info(f"FormatForBigquery window_start: {window_start}, window_end: {window_end}")
        for i in side:
            now = datetime.now()  # current date and time
            date_time = now.strftime("%Y-%m-%d %H:%M:%S")
            # element comes from PercentagesFn, e.g. {'Yes': [percent, total]}.
            percent, total = element['Yes']
            prompt = i.get('prompt')
            yield {
                'time': date_time,
                'prompt': prompt,
                'totalMessages': total,
                'detectionPercent': percent,
            }

class FormatForBigqueryMessages(beam.DoFn):
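    """Formats a single (prediction, prompt) pair into a BigQuery row dict."""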
    def process(self, element):
        now = datetime.now()  # current date and time
        date_time = now.strftime("%Y-%m-%d %H:%M:%S")
        prediction, text = element[0], element[1]
        return [{
            'time': date_time,
            'prompt': text,
            'modelName': model_name,
            'isDetected': prediction,
        }]

class PercentagesFn(beam.CombineFn):
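    """Counts predictions per label and emits {label: [percent_of_total, total]}."""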
    def create_accumulator(self):
        # Maps each prediction label to its count within the window.
        return {}

    def add_input(self, accumulator, input):
        if input[0] not in accumulator:
            accumulator[input[0]] = 0  # e.g. {'Yes': 0}
        accumulator[input[0]] += 1  # e.g. {'Yes': 1}
        return accumulator

    def merge_accumulators(self, accumulators):
        merged = {}
        for accum in accumulators:
            for item, count in accum.items():
                if item not in merged:
                    merged[item] = 0
                merged[item] += count
        return merged

    def extract_output(self, accumulator):
        total = sum(accumulator.values())
        percentages = {item: [round(count * 100 / total, 2), total] for item, count in accumulator.items()}
        return percentages

class ParDoMerge(beam.DoFn):
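    """Prepends the side-input prompt to each decoded message, quoting the message text."""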
    def process(self, element, side, window=beam.DoFn.WindowParam):
        ts_format = '%Y-%m-%d %H:%M:%S.%f UTC'
        window_start = window.start.to_utc_datetime().strftime(ts_format)
        window_end = window.end.to_utc_datetime().strftime(ts_format)
        logging.info(f"ParDoMerge window_start: {window_start}, window_end: {window_end}")
        for i in side:
            yield i.get('prompt') + '"' + element.decode('utf-8') + '"'

def to_bqrequest(e, sql):
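    """Wraps the SQL query in a ReadFromBigQueryRequest.

    The incoming element `e` acts only as a trigger; the same query is issued
    regardless of its value. Typically applied with beam.FlatMap ahead of
    ReadAllFromBigQuery (pipeline wiring assumed, not shown in this module).
    """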
    from apache_beam.io import ReadFromBigQueryRequest
    yield ReadFromBigQueryRequest(query=sql)

def loadstoremodel():
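    """Downloads the Flan-T5 model and saves its state dict to local disk."""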
    state_dict_path = "saved_model"
    model_name = "google/flan-t5-base"
    # Load the pre-trained model from the Hugging Face registry or local disk.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
    # Save the model's state dict to local disk.
    torch.save(model.state_dict(), state_dict_path)

def to_tensors(input_text: str) -> torch.Tensor:
    """Encodes input text into token tensors.

    Args:
        input_text: Input text for the LLM model.

    Returns:
        Tokenized input tokens, encoded with the module-level tokenizer.
    """
    return tokenizer.encode_plus(
        text=input_text,
        max_length=100,
        add_special_tokens=True,
        padding='max_length',
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors="pt").input_ids[0]

def from_tensors(result: PredictionResult) -> tuple[str, str]:
    """Decodes output token tensors into text.

    Args:
        result: Prediction result from the RunInference transform.

    Returns:
        A tuple of (model response text, decoded input prompt).
    """
    input_tokens = result.example
    decoded_inputs = tokenizer.decode(input_tokens, skip_special_tokens=True)
    decoded_outputs = tokenizer.decode(result.inference, skip_special_tokens=True)
    prompt_and_result = f"Input: {decoded_inputs} \t Output: {decoded_outputs}"
    logging.info('`runinference` : %s', prompt_and_result)
    return (decoded_outputs, decoded_inputs)

def is_Alert(element, testAlert: list):
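    """Passes the element through only when the current window produced any alerts."""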
    # If there are alerts in the current window, write all the messages to the
    # table for further investigation.
    if len(testAlert) > 0:
        return element
    return None

def print_hi(name):
    print(f'Hi, {name}')