bigquery_etl/format_sql/format.py (95 lines of code) (raw):
"""Format SQL."""
import glob
import os
import os.path
import sys
from functools import partial
from multiprocessing.pool import Pool
from pathlib import Path
from typing import Tuple
from sqlglot.errors import ParseError
from bigquery_etl.config import ConfigLoader
from bigquery_etl.format_sql.formatter import reformat # noqa E402
from bigquery_etl.util.common import qualify_table_references_in_file
def skip_format():
    """Return the files matched by the configured `format.skip` glob patterns."""
    skipped_files = []
    # Each configured entry is a glob pattern; `**` is expanded recursively.
    for pattern in ConfigLoader.get("format", "skip", fallback=[]):
        skipped_files.extend(glob.glob(pattern, recursive=True))
    return skipped_files
def skip_qualifying_references():
    """Return the files for which table references should not be fully qualified."""
    patterns = ConfigLoader.get("format", "skip_qualifying_references", fallback=[])
    skipped_files = []
    # Each configured entry is a glob pattern; `**` is expanded recursively.
    for pattern in patterns:
        skipped_files.extend(glob.glob(pattern, recursive=True))
    return skipped_files
def _format_path(check: bool, path: str) -> Tuple[int, int]:
    """
    Format a single SQL file, or only report on it when `check` is set.

    Table references are fully qualified before formatting, unless the path matches
    the configured skip list or qualification is not implemented for the file.

    :param check: When set, the file is never written; a message is printed if it
        needs reformatting.
    :param path: String path of the SQL file to operate on.
    :return: a tuple of two counters. The first is 1 if the file was (or, with `check`,
        would be) reformatted; the second is 1 if the file has invalid syntax.
    """
    sql_path = Path(path)
    original = sql_path.read_text()

    try:
        # str.endswith accepts a tuple and is True if any suffix matches.
        if path.endswith(tuple(skip_qualifying_references())):
            qualified = original
        else:
            qualified = qualify_table_references_in_file(sql_path)
    except NotImplementedError:
        qualified = original  # not implemented for scripts
    except ParseError as e:
        print(f"Invalid syntax found for {path}: {e}")
        return 0, 1

    formatted = reformat(qualified, trailing_newline=True)
    if original == formatted:
        return 0, 0

    if check:
        print(f"Needs reformatting: bqetl format {path}")
        return 1, 0

    with open(path, "w") as fp:
        fp.write(formatted)
    print(f"Reformatted: {path}")
    return 1, 0
def format(paths, check=False, parallelism=8):
    """
    Format SQL files, or SQL read from stdin when no paths are given.

    Directories are walked recursively (following symlinks) for `.sql` files. Files
    ending with an entry of the configured format skip list are left out, as are
    `input.sql` files when the directory argument starts with `tests`. Paths that
    are not directories are queued as given, without consulting the skip list.

    :param paths: File and directory paths to format. When empty, the query is read
        from stdin and, unless `check` is set, the formatted result is printed.
    :param check: When set, nothing is written; the process exits with status 1 if
        anything needs reformatting or is reported invalid.
    :param parallelism: Number of worker processes in the pool that formats files.
    """
    if not paths:
        query = sys.stdin.read()
        formatted = reformat(query, trailing_newline=True)
        if not check:
            print(formatted, end="")
        elif query != formatted:
            sys.exit(1)
        return

    sql_files = []
    # Resolved on first use and then reused: building the skip list reads the config
    # and runs recursive globs, which is far too costly to repeat for every file.
    skip_suffixes = None
    for path in paths:
        if not os.path.isdir(path):
            if path:
                sql_files.append(path)
            continue

        in_tests = path.startswith("tests")
        for dirpath, _, filenames in os.walk(path, followlinks=True):
            for filename in filenames:
                if not filename.endswith(".sql"):
                    continue
                # skip tests/**/input.sql
                if in_tests and filename == "input.sql":
                    continue
                if skip_suffixes is None:
                    skip_suffixes = tuple(skip_format())
                filepath = os.path.join(dirpath, filename)
                # str.endswith accepts a tuple and is True if any suffix matches.
                if not filepath.endswith(skip_suffixes):
                    sql_files.append(filepath)

    if not sql_files:
        print("Error: no files were found to format")
        sys.exit(255)

    with Pool(parallelism) as pool:
        results = pool.map(partial(_format_path, check), sql_files)

    reformatted = sum(result[0] for result in results)
    unchanged = len(sql_files) - reformatted
    invalid = sum(result[1] for result in results)
    print(
        ", ".join(
            f"{number} file{'s' if number > 1 else ''}"
            f"{' would be' if check else ''} {msg}"
            for number, msg in [
                (reformatted, "reformatted"),
                (unchanged, "left unchanged"),
                (invalid, "invalid"),
            ]
            if number > 0
        )
        + "."
    )
    if check and (reformatted or invalid):
        sys.exit(1)