runinferenceutil/infra.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from datetime import datetime

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import PredictionResult, RunInference
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_RESPONSE_TOKENS = 256
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)


class FormatForBigquery(beam.DoFn):
    """Formats the per-window detection summary into a BigQuery row."""

    def process(self, element, side, window=beam.DoFn.WindowParam):
        ts_format = '%Y-%m-%d %H:%M:%S.%f UTC'
        window_start = window.start.to_utc_datetime().strftime(ts_format)
        window_end = window.end.to_utc_datetime().strftime(ts_format)
        logging.info(
            "FormatForBigquery window_start: %s, window_end: %s",
            window_start, window_end)
        # Emit a single summary row per window, annotated with the prompt
        # taken from the first side-input element.
        for i in side:
            date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            percent, total = element['Yes'][0], element['Yes'][1]
            prompt = i.get('prompt')
            return [{
                'time': date_time,
                'prompt': prompt,
                'totalMessages': total,
                'detectionPercent': percent,
            }]


class FormatForBigqueryMessages(beam.DoFn):
    """Formats a single (prediction, text) pair into a BigQuery row."""

    def process(self, element):
        date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        prediction, text = element[0], element[1]
        return [{
            'time': date_time,
            'prompt': text,
            'modelName': model_name,
            'isDetected': prediction,
        }]


class PercentagesFn(beam.CombineFn):
    """Counts predictions per label and converts the counts to percentages."""

    def create_accumulator(self):
        return {}

    def add_input(self, accumulator, input):
        if input[0] not in accumulator:
            accumulator[input[0]] = 0  # e.g. {'Yes': 0}
        accumulator[input[0]] += 1     # e.g. {'Yes': 1}
        return accumulator

    def merge_accumulators(self, accumulators):
        merged = {}
        for accum in accumulators:
            for item, count in accum.items():
                if item not in merged:
                    merged[item] = 0
                merged[item] += count
        return merged

    def extract_output(self, accumulator):
        total = sum(accumulator.values())
        # Map each label to [percentage, total count], e.g. {'Yes': [66.67, 3]}.
        percentages = {
            item: [round(count * 100 / total, 2), total]
            for item, count in accumulator.items()
        }
        return percentages


class ParDoMerge(beam.DoFn):
    """Prepends the side-input prompt to each message (bytes) element."""

    def process(self, element, side, window=beam.DoFn.WindowParam):
        ts_format = '%Y-%m-%d %H:%M:%S.%f UTC'
        window_start = window.start.to_utc_datetime().strftime(ts_format)
        window_end = window.end.to_utc_datetime().strftime(ts_format)
        logging.info(
            "ParDoMerge window_start: %s, window_end: %s",
            window_start, window_end)
        for i in side:
            yield i.get('prompt') + '"' + element.decode('utf-8') + '"'


def to_bqrequest(e, sql):
    """Wraps a SQL statement into a ReadFromBigQueryRequest."""
    from apache_beam.io import ReadFromBigQueryRequest
    yield ReadFromBigQueryRequest(query=sql)


def loadstoremodel():
    """Downloads the pre-trained model and saves its state dict to local disk."""
    state_dict_path = "saved_model"
    model_name = "google/flan-t5-base"
    # Load the pre-trained model from the Hugging Face registry or local cache.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16)
    # Save the model state dict to local disk.
    torch.save(model.state_dict(), state_dict_path)


def to_tensors(input_text: str) -> torch.Tensor:
    """Encodes input text into token tensors using the module-level tokenizer.

    Args:
      input_text: Input text for the LLM model.

    Returns:
      Tokenized input tokens.
    """
    return tokenizer.encode_plus(
        text=input_text,
        max_length=100,
        add_special_tokens=True,
        padding='max_length',
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors="pt").input_ids[0]


def from_tensors(result: PredictionResult) -> tuple[str, str]:
    """Decodes output token tensors into text.

    Args:
      result: Prediction result from the RunInference transform.

    Returns:
      A (response, prompt) tuple of decoded text.
    """
    input_tokens = result.example
    decoded_inputs = tokenizer.decode(input_tokens, skip_special_tokens=True)
    decoded_outputs = tokenizer.decode(result.inference, skip_special_tokens=True)
    prompt_and_result = f"Input: {decoded_inputs} \t Output: {decoded_outputs}"
    logging.info('`runinference` : %s', prompt_and_result)
    return (decoded_outputs, decoded_inputs)


def is_Alert(element, testAlert: list):
    # If there are alerts in the current block, pass the messages through so
    # they are written to the table for further investigation.
    if len(testAlert) > 0:
        return element
    return None  # No alerts: emit nothing.


def print_hi(name):
    # Use a breakpoint in the line below to debug your script.
    print(f'Hi, {name}')  # Press ⌘F8 to toggle the breakpoint.

# See PyCharm help at https://www.jetbrains.com/help/pycharm/