experiments/arena/config/spanner_config.py (176 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner table creation script: tables and indexes used by the Arena Study."""
# Prerequisite: Create a Spanner instance "arena_study" on GCP console with 100 processing units.
from dataclasses import dataclass, field, fields
import dataclasses
from datetime import datetime
from enum import Enum
import logging
import secrets
import string
from typing import Optional
from google.cloud import spanner
from utils.logger import LogLevel, log
from config.default import Default
config = Default()
@dataclass
class ArenaModelEvaluation():
"""This class maps 1:1 to DB table 'Study'. IO is handled by Spanner ORM."""
model_name: str = field(default=None) # Default model name if not provided
study: str = field(default=None) # Default study name if not provided
time_of_rating: datetime | None = field(default=None) # Time of rating
rating: float = field(default=1000.0) # Default rating is 1000.0 if not provided
id: str = field(default_factory=lambda: '') # Unique identifier for the study run
def __post_init__(self):
if not isinstance(self.model_name, str) or not self.model_name:
raise ValueError("model_name must be a non-empty string.")
if self.time_of_rating is not None and not isinstance(self.time_of_rating, datetime):
raise ValueError("time_of_rating must be a datetime object.")
if not isinstance(self.rating, (float, int)):
raise ValueError("rating must be a float or int.")
if not isinstance(self.study, str) or not self.study:
raise ValueError("study must be a non-empty string.")
if not isinstance(self.id, str):
raise ValueError("id must be a string.")
log(f"Initialized StudyRun: {self.model_name}, {self.time_of_rating}, {self.rating}, {self.study}, {self.id}")
class ArenaStudyTracker:
"""Arena Study Tracker for managing study runs in Spanner (Singleton)."""
_instance = None
def __new__(cls, project_id: str, spanner_instance_id: str, spanner_database_id: str):
if cls._instance is None:
cls._instance = super(ArenaStudyTracker, cls).__new__(cls)
cls._instance.project_id = project_id
cls._instance.spanner_instance_id = spanner_instance_id
cls._instance.spanner_database_id = spanner_database_id
cls._instance.client = spanner.Client(project=project_id)
cls._instance.instance = cls._instance.client.instance(spanner_instance_id)
cls._instance.database = cls._instance.instance.database(spanner_database_id)
log("ArenaStudyTracker instance created.")
return cls._instance
def _generate_unique_id(self, number_characters: int = 8) -> str:
"""Generate a unique ID of a specified length."""
characters = string.ascii_uppercase + string.ascii_lowercase + string.digits
unique_id = ''.join(secrets.choice(characters) for _ in range(number_characters))
log(f"Generated unique ID: {unique_id}")
return unique_id
def upsert_study_runs(self, study_runs: list[ArenaModelEvaluation], table_name: Optional[str] = "Study"):
"""Adds or updates a list of study runs in the Spanner database."""
inserts = []
updates = []
current_timestamp = spanner.COMMIT_TIMESTAMP
for study_run in study_runs:
is_insert = False
if not study_run.id:
study_run.id = self._generate_unique_id()
is_insert = True
if not study_run.time_of_rating:
study_run.time_of_rating = current_timestamp
is_insert = True
log("Setting time_of_rating to commit timestamp as it was not provided.")
columns = [field.name for field in fields(ArenaModelEvaluation)]
values = []
for field in fields(ArenaModelEvaluation):
value = getattr(study_run, field.name)
if isinstance(value, datetime):
value = value.isoformat()
if isinstance(value, Enum):
value = str(value)
values.append(value)
if is_insert:
inserts.append(values)
else:
updates.append(values)
try:
with self.database.batch() as batch:
if inserts:
log(f"Inserting {len(inserts)} new study runs into the database.")
batch.insert(table_name, columns=columns, values=inserts)
if updates:
log(f"Updating {len(updates)} existing study runs in the database.")
batch.update(table_name, columns=columns, values=updates)
log(f"{len(study_runs)} study runs added/updated successfully in the database.")
except Exception as e:
raise Exception(f"Error adding study runs: {e}") from e
finally:
self._close_connection()
def close(self):
"""Close the Spanner client connection."""
self._close_connection()
def _close_connection(self):
"""Internal method to close the Spanner client connection."""
if self._instance and self._instance.client:
self._instance.client.close()
self._instance.client = None
log("Database connection closed.")
else:
log("Client was already closed or not initialized.", LogLevel.WARNING)
def __del__(self):
"""Destructor to ensure the Spanner client is closed when the object is deleted."""
self.close()
log("ArenaStudyTracker instance deleted and database connection closed.")
def __enter__(self):
"""Enter the runtime context for the ArenaStudyTracker."""
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the runtime context for the ArenaStudyTracker and close the client."""
self.close()
if exc_type:
log(f"Exception occurred: {exc_value}", LogLevel.ERROR)
return False
class ArenaStudySchema():
"""Arena Study Schema"""
def __init__(self, project_id: str, spanner_instance_id: str, spanner_database_id: str):
self.project_id = project_id
self.spanner_instance_id = spanner_instance_id
self.spanner_database_id = spanner_database_id
self.client = spanner.Client(project_id)
self.instance = self.client.instance(self.spanner_instance_id)
self.database = self.instance.database(self.spanner_database_id)
def create_database(self, exists_ok: bool = False):
"""Create Spanner database if it does not exist."""
try:
operation = self.database.create()
# Wait for the operation to complete
log(f"Creating database {self.spanner_database_id}.")
log(f"Adding to Spanner instance {self.spanner_instance_id} in project {self.project_id}.")
operation.result(config.SPANNER_TIMEOUT)
log("Database created successfully.")
except Exception as e:
if "Database already exists" in str(e) and exists_ok:
log("Database already exists, skipping creation.")
return
raise Exception(f"Error creating database: {e}") from e
def create_study_table(self):
"""Create study table"""
try:
log("Creating study table.")
operation = self.database.update_ddl([
"""
CREATE TABLE Study (
model_name STRING(MAX) NOT NULL,
time_of_rating TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true),
rating FLOAT64 NOT NULL,
study STRING(MAX) NOT NULL,
id STRING(MAX) NOT NULL
)
PRIMARY KEY (study, model_name, time_of_rating)
"""])
operation.result(config.SPANNER_TIMEOUT)
log("Study table created successfully.")
except Exception as e:
raise Exception(f"Error creating study table: {e}") from e
def create_study_index(self):
"""Create study index"""
try:
log("Creating study index.")
# To query ratings for a specific model across all studies
operation = self.database.update_ddl([
"""
CREATE INDEX StudyByModel ON Study(model_name)
"""
])
operation.result(config.SPANNER_TIMEOUT)
log("Study index created successfully.")
except Exception as e:
raise Exception(f"Error creating study index: {e}") from e
finally:
self.client.close()
log("Database connection closed.")
def create_schema(self):
"""Create schema"""
try:
log("Creating schema.")
self.create_study_table()
self.create_study_index()
log("Schema created successfully.")
except Exception as e:
raise Exception(f"Error creating schema: {e}") from e
finally:
self.client.close()
log("Database connection closed.")
def drop_schema(self):
"""Drop schema"""
try:
log("Dropping schema.")
operation = self.database.update_ddl(["DROP INDEX StudyByModel", "DROP TABLE Study"])
operation.result(config.SPANNER_TIMEOUT)
log("Schema dropped successfully.")
except Exception as e:
raise Exception(f"Error dropping schema: {e}") from e
finally:
self.client.close()
log("Database connection closed.")
if __name__ == "__main__":
# Create an instance of the ArenaStudySchema class
schema = ArenaStudySchema(
project_id=config.PROJECT_ID,
spanner_instance_id=config.SPANNER_INSTANCE_ID,
spanner_database_id=config.SPANNER_DATABASE_ID
)
# TODO: DEPLOYMENT - Uncomment the following line to create the schema and tables
# schema.create_schema()
# schema.drop_schema()