notebooks/util/notebook_functions.py (25 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
def split_list(l: list, split_size: int) -> List[list]:
    """Split list ``l`` into consecutive sublists of at most ``split_size`` items.

    The final sublist holds any remainder and may be shorter than
    ``split_size``. An empty input yields an empty result.

    Args:
        l: The list to split. It is not modified.
        split_size: Maximum length of each sublist; must be at least 1.

    Returns:
        A new list of slices of ``l`` which, concatenated in order, equal ``l``.

    Raises:
        ValueError: If ``split_size`` is less than 1.
    """
    # range() already rejects a zero step, but a negative step would silently
    # yield an empty result and drop every element, so reject both up front.
    if split_size < 1:
        raise ValueError(f"split_size must be >= 1, got {split_size}")
    return [l[i : i + split_size] for i in range(0, len(l), split_size)]
def remove_unexpected_spanner_primary_keys(
    spanner_pk_dict: dict, expected_names: List[str]
):
    """
    Mutates spanner_pk_dict to remove unexpected entries, i.e. name NOT IN expected_names.

    Args:
        spanner_pk_dict: Primary key entries keyed by table name; modified in place.
        expected_names: Table names whose entries should be kept.

    Returns:
        None. A message listing the removed table names is printed when any
        entries are removed.
    """
    # Build the lookup set once so each membership test is O(1).
    expected = set(expected_names)
    # NOTE(review): this comparison is case-sensitive, whereas
    # update_spanner_primary_keys matches table names case-insensitively —
    # confirm callers normalise key case before calling this.
    # Collect the keys first: deleting from a dict while iterating it raises RuntimeError.
    unknown_tables = [name for name in spanner_pk_dict if name not in expected]
    if unknown_tables:
        print(f"Removing primary key entries for unknown tables: {unknown_tables}")
    for to_remove in unknown_tables:
        del spanner_pk_dict[to_remove]
def update_spanner_primary_keys(
    spanner_pk_dict: dict, table_name: str, source_pk_columns: list
):
    """
    Mutates spanner_pk_dict to reflect table_name: source_pk_columns.
    Also corrects case of table_name key if the user already provided a column list.

    If an entry for table_name already exists with exactly the same case, it
    is left untouched. Otherwise, if an entry exists whose key matches
    table_name case-insensitively, its user-provided value is kept and the
    entry is re-keyed under table_name. If no entry matches, table_name is
    added with source_pk_columns joined by commas ("" when source_pk_columns
    is empty or None).

    Args:
        spanner_pk_dict: Primary key entries keyed by table name; modified in place.
        table_name: Table name in the case to use as the dict key.
        source_pk_columns: Column names used only when no entry matches.

    Returns:
        None.
    """
    # An exact-case key is already correct; checking it first prevents a
    # different-case duplicate from overwriting the user's value.
    if table_name in spanner_pk_dict:
        return
    upper_name = table_name.upper()
    matching_key = next(
        (key for key in spanner_pk_dict if key.upper() == upper_name), None
    )
    if matching_key is not None:
        # Ensure user provided table_name is of correct case.
        spanner_pk_dict[table_name] = spanner_pk_dict.pop(matching_key)
    else:
        spanner_pk_dict[table_name] = ",".join(source_pk_columns or [])