src/utils/interactive/completers.py (184 lines of code) (raw):
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""prompt_toolkit completers"""
import logging
import pathlib
import shutil
import subprocess
from typing import Any, Iterable, List, Optional

from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.cloud import storage
import googleapiclient.discovery
from googleapiclient.errors import HttpError
from prompt_toolkit.completion import Completer, Completion, WordCompleter
class GCPProjectCompleter(Completer):
    """Completer with GCP project IDs.

    Projects are looked up through the Cloud Resource Manager v1 API. If that
    API answers a listing request with HTTP 403 (e.g. the API is not enabled),
    the completer switches to the ``gcloud`` CLI for this and all later
    lookups.
    """

    def __init__(self, credentials: Optional[Credentials] = None):
        """Create the Resource Manager v1 client.

        Args:
            credentials: Optional credentials for the API client.
        """
        self.service_v1 = googleapiclient.discovery.build(
            "cloudresourcemanager",
            "v1",
            credentials=credentials,
            cache_discovery=False)
        # Becomes True after the API returned 403; gcloud is used from then on.
        self.use_cli = False

    def get_completions(self, document, complete_event) -> Iterable[Completion]:
        """Yield project completions for the document's current line."""
        del complete_event  # Unused.
        return self.get_projects_completions(document.current_line,
                                             document.cursor_position)

    def get_projects_completions(self,
                                 filter_str: str,
                                 cursor_position: int) -> Iterable[Completion]:
        """Yield a Completion for every project ID matching ``filter_str``.

        Inputs of one character or less yield nothing. Each completion
        replaces the text between the start of the line and the cursor.

        Args:
            filter_str: Text typed so far, used as the search term.
            cursor_position: Cursor offset within the line.
        """
        if len(filter_str) <= 1:
            return
        # All lookups happen eagerly so that logging is suppressed only while
        # the API calls run, not while this generator is suspended at yield.
        for project_id in self._find_project_ids(filter_str):
            yield Completion(project_id, start_position=-cursor_position)

    def _find_project_ids(self, filter_str: str) -> List[str]:
        """Return project IDs matching ``filter_str`` via the API or gcloud."""
        if self.use_cli:
            return self.get_projects_cli(filter_str)
        try:
            logging.disable(logging.ERROR)
            try:
                project_ids = list(self.get_projects(filter_str))
            except HttpError as ex:
                # If Resource Manager API is not enabled, try using gcloud CLI.
                if ex.status_code != 403:
                    raise
                self.use_cli = True
                return self.get_projects_cli(filter_str)
            # Nothing listed: the input may still be an exact project ID the
            # caller can read, so try fetching it directly.
            if not project_ids and self._project_exists(filter_str):
                project_ids = [filter_str]
            return project_ids
        finally:
            logging.disable(logging.NOTSET)

    def _project_exists(self, project_id: str) -> bool:
        """Return True if ``projects().get`` returns a non-empty result."""
        try:
            project = (self.service_v1.projects().get(
                projectId=project_id).execute())
        except Exception:  # pylint:disable=broad-except
            # No project or access denied - expected.
            return False
        return bool(project)

    @staticmethod
    def _build_filter(filter_str: Optional[str]) -> str:
        """Build the projects filter expression for active projects."""
        if filter_str is None:
            return "lifecycleState:ACTIVE"
        return (f"projectId:*{filter_str}* name:*{filter_str}* "
                "lifecycleState:ACTIVE")

    def get_projects(self, filter_str: Optional[str] = None) -> Iterable[Any]:
        """Return active project IDs from the first page of projects.list.

        Args:
            filter_str: Optional search term applied to project ID and name.

        Returns:
            A list of project ID strings (possibly empty).
        """
        page_results = self.service_v1.projects().list(
            filter=self._build_filter(filter_str)).execute()
        if not page_results:
            return []
        return [p["projectId"] for p in page_results.get("projects", [])]

    def get_projects_cli(self, filter_str: Optional[str] = None) -> List[str]:
        """Return up to 100 active project IDs using the gcloud CLI.

        Args:
            filter_str: Optional search term applied to project ID and name.

        Returns:
            Project IDs, or an empty list if gcloud is missing or fails.
        """
        gcloud = shutil.which("gcloud")
        if gcloud is None:
            return []
        # Arguments are passed as a list without a shell: filter_str is user
        # input and must not be interpreted (or break quoting) in a shell.
        args = [
            gcloud, "projects", "list",
            f"--filter={self._build_filter(filter_str)}",
            "--format=value(projectId)",
            "--limit=100",
            "--quiet",
        ]
        try:
            gcloud_result = subprocess.run(args,
                                           stdout=subprocess.PIPE,
                                           check=False)
        except OSError:
            return []
        if gcloud_result.returncode != 0 or not gcloud_result.stdout:
            return []
        output = gcloud_result.stdout.decode("utf-8")
        return [line.strip() for line in output.splitlines() if line.strip()]

    def get_project_number(self, project_id: str):
        """Return the ``projectNumber`` field of ``project_id``.

        Raises:
            HttpError: If the project cannot be fetched.
        """
        project_result = (self.service_v1.projects().get(
            projectId=project_id).execute())
        return project_result["projectNumber"]
class BigQueryDatasetCompleter(WordCompleter):
    """Completer with the BigQuery dataset IDs of a project."""

    def __init__(self, project_id: str,
                 client: Optional[bigquery.Client] = None):
        """List the datasets of ``project_id`` and offer their IDs as words.

        Matching ignores case and may match in the middle of a dataset ID.

        Args:
            project_id: Project whose datasets are listed.
            client: Optional BigQuery client; when omitted, one is created
                for ``project_id``.
        """
        self.project = project_id
        self.client = client
        try:
            # Suppress log records at ERROR level and below while the client
            # is created and the datasets are listed.
            logging.disable(logging.ERROR)
            if self.client is None:
                self.client = bigquery.Client(self.project)
            dataset_ids = [
                d.reference.dataset_id
                for d in self.client.list_datasets(self.project)
            ]
            super().__init__(
                words=dataset_ids,
                ignore_case=True,
                match_middle=True,
            )
        finally:
            logging.disable(logging.NOTSET)
class StorageBucketCompleter(Completer):
    """Completer with Cloud Storage bucket names."""

    def __init__(self,
                 project_id: Optional[str] = None,
                 client: Optional[storage.Client] = None):
        """Store the Storage client to query.

        Args:
            project_id: Project used to create a client when none is given.
            client: Optional pre-built Storage client.
        """
        if client is None:
            self.client = storage.Client(project_id)
        else:
            self.client = client

    def get_completions(self, document, complete_event):
        """Yield bucket names starting with the document's current line.

        Inputs of one character or less yield nothing.
        """
        del complete_event
        filter_str = document.current_line
        if len(filter_str) <= 1:
            return
        try:
            logging.disable(logging.ERROR)
            # list_buckets returns a lazy iterator that issues its HTTP
            # requests while being iterated. Consume it here so the requests
            # (and any log output) happen while logging is suppressed.
            bucket_names = [b.name for b in self.get_buckets(filter_str)]
        finally:
            logging.disable(logging.NOTSET)
        for name in bucket_names:
            yield Completion(name,
                             start_position=-document.cursor_position)

    def get_buckets(self, filter_str: Optional[str] = None) -> Iterable[Any]:
        """Return an iterator over buckets whose names start with filter_str.

        Results are requested in pages of 100.
        """
        page_results = self.client.list_buckets(
            prefix=filter_str,
            page_size=100,
        )
        return page_results
class RegionsCompleter(WordCompleter):
    """Completer with BigQuery regions listed in ``bq_regions.txt``."""

    def __init__(self,
                 include_multiregions=True,
                 source_region: Optional[str] = None):
        """Load, filter and sort the region names offered as completions.

        Args:
            include_multiregions: When False, the ``us`` and ``eu``
                multi-regions are left out.
            source_region: Optional prefix; when given and non-empty, only
                regions starting with it (case-insensitive) are kept.
        """
        data_path = pathlib.Path(__file__).parent / "bq_regions.txt"
        content = data_path.read_text("utf-8")

        names = []
        for raw in content.split("\n"):
            name = raw.lower().strip()
            if not name:
                # Skip blank lines.
                continue
            if not include_multiregions and name in ("us", "eu"):
                # Drop `us` and `eu` when multi-regions are not wanted.
                continue
            names.append(name)
        names.sort()

        # An empty prefix matches every name, which covers the
        # "no source region" case.
        prefix = "" if not source_region else source_region.lower()
        self.regions = [name for name in names if name.startswith(prefix)]

        super().__init__(words=self.regions,
                         ignore_case=True,
                         match_middle=False)
class ServiceAccountsCompleter(WordCompleter):
    """Completer with the e-mail addresses of a project's service accounts."""

    def __init__(self,
                 project_id: str,
                 credentials: Optional[Credentials] = None):
        """Load all service accounts of ``project_id`` from the IAM API.

        Args:
            project_id: Project whose service accounts are listed.
            credentials: Optional credentials for the IAM API client.
        """
        try:
            logging.disable(logging.ERROR)
            service = googleapiclient.discovery.build("iam",
                                                      "v1",
                                                      credentials=credentials,
                                                      cache_discovery=False)
            accounts_api = service.projects().serviceAccounts()
            # The list method is paginated; follow next-page tokens via
            # list_next (None when exhausted) instead of reading one page.
            request = accounts_api.list(name="projects/" + project_id,
                                        pageSize=100)
            self.accounts = []
            while request is not None:
                response = request.execute()
                self.accounts.extend(
                    a["email"] for a in response.get("accounts", []))
                request = accounts_api.list_next(previous_request=request,
                                                 previous_response=response)
            super().__init__(words=self.accounts,
                             ignore_case=True,
                             match_middle=False)
        finally:
            logging.disable(logging.NOTSET)