google/generativeai/operations.py (92 lines of code) (raw):

# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import functools from typing import Iterator from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai.types import model_types from google.api_core import operation as operation_lib import tqdm.auto as tqdm def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]: """Calls the API to list all operations""" if client is None: client = client_lib.get_default_operations_client() # The client returns an iterator of Operation protos (`Iterator[google.longrunning.operations_pb2.Operation]`) # not a gapic Operation object (`google.api_core.operation.Operation`) operations = ( CreateTunedModelOperation.from_proto(op, client) for op in client.list_operations(name="", filter_="") ) return operations def get_operation(name: str, *, client=None) -> CreateTunedModelOperation: """Calls the API to get a specific operation""" if client is None: client = client_lib.get_default_operations_client() op = client.get_operation(name=name) return CreateTunedModelOperation.from_proto(op, client) def delete_operation(name: str, *, client=None): """Calls the API to delete a specific operation""" # Raises:google.api_core.exceptions.MethodNotImplemented: Not implemented. if client is None: client = client_lib.get_default_operations_client() return client.delete_operation(name=name) class CreateTunedModelOperation(operation_lib.Operation): @classmethod def from_proto(cls, proto, client): """ result = getattr(proto, 'result', None) if result is not None: if result.value == b'': del proto.result """ return from_gapic( cls=CreateTunedModelOperation, operation=proto, operations_client=client, result_type=protos.TunedModel, metadata_type=protos.CreateTunedModelMetadata, ) @classmethod def from_core_operation( cls, operation: operation_lib.Operation, ): polling = getattr(operation, "_polling", None) retry = getattr(operation, "_retry", None) if polling is not None: # google.api_core v 2.11 kwargs = {"polling": polling} elif retry is not None: # google.api_core v 2.10 kwargs = {"retry": retry} else: kwargs = {} return cls( operation=operation._operation, refresh=operation._refresh, cancel=operation._cancel, result_type=operation._result_type, metadata_type=operation._metadata_type, **kwargs, ) @property def name(self) -> str: return self._operation.name def update(self): """Refresh the current statuses in metadata/result/error""" self._refresh_and_update() def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]: """A tqdm wait bar, yields `Operation` statuses until complete. Args: **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` Yields: Operation statuses as `protos.CreateTunedModelMetadata` objects. """ bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) # done() includes a `_refresh_and_update` while not self.done(): metadata = self.metadata bar.update(self.metadata.completed_steps - bar.n) yield metadata metadata = self.metadata bar.update(self.metadata.completed_steps - bar.n) return self.result() def set_result(self, result: protos.TunedModel): result = model_types.decode_tuned_model(result) super().set_result(result) def from_gapic( cls, *, operation, operations_client, result_type, metadata_type, grpc_metadata=None, **kwargs, ): """`google.api_core.operation.from_gapic`, patched to allow subclasses.""" refresh = functools.partial( operations_client.get_operation, operation.name, metadata=grpc_metadata ) cancel = functools.partial( operations_client.cancel_operation, operation.name, metadata=grpc_metadata, ) return cls(operation, refresh, cancel, result_type, metadata_type, **kwargs)