google/generativeai/notebook/gspread_client.py (124 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module that holds a global gspread.client.Client."""
from __future__ import annotations
import abc
import datetime
from typing import Any, Callable, Mapping, Sequence
from google.auth import credentials
from google.generativeai.notebook import html_utils
from google.generativeai.notebook import ipython_env
from google.generativeai.notebook import sheets_id
# The code may be running in an environment where the gspread library has not
# been installed.
_gspread_import_error: Exception | None = None
try:
# pylint: disable-next=g-import-not-at-top
import gspread
except ImportError as e:
_gspread_import_error = e
gspread = None
# Base class of exceptions that gspread.open(), open_by_url() and open_by_key()
# may throw.
GSpreadException = Exception if gspread is None else gspread.exceptions.GSpreadException # type: ignore
class SpreadsheetNotFoundError(RuntimeError):
pass
def _get_import_error() -> Exception:
return RuntimeError('"gspread" module not imported, got: {}'.format(_gspread_import_error))
class GSpreadClient(abc.ABC):
"""Wrapper around gspread.client.Client.
This adds a layer of indirection for us to inject mocks for testing.
"""
@abc.abstractmethod
def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
"""Validates that `name` is the name of a Google Sheets document.
Raises an exception if false.
Args:
sid: The identifier for the document.
"""
@abc.abstractmethod
def get_all_records(
self,
sid: sheets_id.SheetsIdentifier,
worksheet_id: int,
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
"""Returns all records for a Google Sheets worksheet."""
@abc.abstractmethod
def write_records(
self,
sid: sheets_id.SheetsIdentifier,
rows: Sequence[Sequence[Any]],
) -> None:
"""Writes results to a new worksheet to the Google Sheets document."""
class GSpreadClientImpl(GSpreadClient):
"""Concrete implementation of GSpreadClient."""
def __init__(self, client: Any, env: ipython_env.IPythonEnv | None):
"""Constructor.
Args:
client: Instance of gspread.client.Client.
env: Optional instance of IPythonEnv. This is used to display messages
such as the URL of the output Worksheet.
"""
self._client = client
self._ipython_env = env
def _open(self, sid: sheets_id.SheetsIdentifier):
"""Opens a Sheets document from `sid`.
Args:
sid: The identifier for the Sheets document.
Raises:
SpreadsheetNotFoundError: If the Sheets document cannot be found or
cannot be opened.
Returns:
A gspread.Worksheet instance representing the worksheet referred to by
`sid`.
"""
try:
if sid.name():
return self._client.open(sid.name())
if sid.key():
return self._client.open_by_key(str(sid.key()))
if sid.url():
return self._client.open_by_url(str(sid.url()))
except GSpreadException as exc:
raise SpreadsheetNotFoundError("Unable to find Sheets with {}".format(sid)) from exc
raise SpreadsheetNotFoundError("Invalid sheets_id.SheetsIdentifier")
def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
self._open(sid)
def get_all_records(
self,
sid: sheets_id.SheetsIdentifier,
worksheet_id: int,
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
sheet = self._open(sid)
worksheet = sheet.get_worksheet(worksheet_id)
if self._ipython_env is not None:
env = self._ipython_env
def _display_fn():
env.display_html(
"Reading inputs from worksheet {}".format(
html_utils.get_anchor_tag(
url=sheets_id.SheetsURL(worksheet.url),
text="{} in {}".format(worksheet.title, sheet.title),
)
)
)
else:
def _display_fn():
print("Reading inputs from worksheet {} in {}".format(worksheet.title, sheet.title))
return worksheet.get_all_records(), _display_fn
def write_records(
self,
sid: sheets_id.SheetsIdentifier,
rows: Sequence[Sequence[Any]],
) -> None:
sheet = self._open(sid)
# Create a new Worksheet.
# `title` has to be carefully constructed: some characters like colon ":"
# will not work with gspread in Worksheet.append_rows().
current_datetime = datetime.datetime.now()
title = f"Results {current_datetime:%Y_%m_%d} ({current_datetime:%s})"
# append_rows() will resize the worksheet as needed, so `rows` and `cols`
# can be set to 1 to create a worksheet with only a single cell.
worksheet = sheet.add_worksheet(title=title, rows=1, cols=1)
worksheet.append_rows(values=rows)
if self._ipython_env is not None:
self._ipython_env.display_html(
"Results written to new worksheet {}".format(
html_utils.get_anchor_tag(
url=sheets_id.SheetsURL(worksheet.url),
text="{} in {}".format(worksheet.title, sheet.title),
)
)
)
else:
print("Results written to new worksheet {} in {}".format(worksheet.title, sheet.title))
class NullGSpreadClient(GSpreadClient):
"""Null-object implementation of GSpreadClient.
This class raises an error if any of its methods are called. It is used when
the gspread library is not available.
"""
def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
raise _get_import_error()
def get_all_records(
self,
sid: sheets_id.SheetsIdentifier,
worksheet_id: int,
) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
raise _get_import_error()
def write_records(
self,
sid: sheets_id.SheetsIdentifier,
rows: Sequence[Sequence[Any]],
) -> None:
raise _get_import_error()
# Global instance of gspread client.
_gspread_client: GSpreadClient | None = None
def authorize(creds: credentials.Credentials, env: ipython_env.IPythonEnv | None) -> None:
"""Sets up credential for gspreads."""
global _gspread_client
if gspread is not None:
client = gspread.authorize(creds) # type: ignore
_gspread_client = GSpreadClientImpl(client=client, env=env)
else:
_gspread_client = NullGSpreadClient()
def get_client() -> GSpreadClient:
if not _gspread_client:
raise RuntimeError("Must call authorize() first")
return _gspread_client
def testonly_set_client(client: GSpreadClient) -> None:
"""Overrides the global client for testing."""
global _gspread_client
_gspread_client = client