# bsp_server/scip_sync_util/scip_utils.py
import hashlib
import json
import multiprocessing
import os
import os.path
import re
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from typing import Optional
from bsp_server.scip_sync_util import scip_const
from bsp_server.util import utils
def parse_bazelproject(file_path: str) -> dict[str, list[str]]:
"""Parse a .bazelproject file and return a dictionary of the contents."""
data = {}
last_key = ""
with open(file_path, "r") as file:
for line in file:
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith("#"):
continue
key, value = _get_key_value(line)
if key != "":
data[key] = []
last_key = key
if value != "":
data[last_key].append(value)
return data
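# Illustrative only: for a hypothetical .bazelproject containing
#   directories:
#     src/main/java
#   workspace_type: intellij
# parse_bazelproject returns {"directories": ["src/main/java"], "workspace_type": ["intellij"]}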
def _get_key_value(line: str) -> tuple[str, str]:
    """Split a .bazelproject line into (key, value); a line starting with // is a value."""
    line = line.strip()
key = ""
value = line
if line.startswith("//"):
return key, value
if ":" in line:
k, _, v = line.partition(":")
key = k.strip()
value = v.strip()
return key, value
def get_containing_bazel_target(cwd: str, filepath: str, query_kinds: list[str]) -> str:
    """Query bazel for the target(s) of the given kinds that depend on filepath."""
    from_target = "//" + filepath.rpartition("/src")[0] + "/..."
query_string = (
f'kind("{"|".join(query_kinds)}", rdeps("{from_target}", "{filepath}"))'
)
cmd = [
"bazel",
"query",
query_string,
]
return utils.output(cmd, cwd=cwd)
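# Sketch of the query this builds (the filepath below is hypothetical): for
# filepath "foo/bar/src/main/java/Baz.java" and query_kinds ["java_library"],
# from_target is "//foo/bar/..." and the query becomes
#   kind("java_library", rdeps("//foo/bar/...", "foo/bar/src/main/java/Baz.java"))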
def old_copy_index(index_to_copy: set[str], dest: str) -> None:
utils.safe_create(dest, is_dir=True)
for index in index_to_copy:
parts = index.split(os.path.sep + "bin" + os.path.sep)
idx = parts[-1]
index_filename = idx.replace("/", "_").replace("-", "_")
dest_path = os.path.join(dest, index_filename)
shutil.copy(index, dest_path)
sha_filename = index_filename + ".sha256"
with open(os.path.join(dest, sha_filename), "w") as f:
f.write(generate_sha256(index) + "\n")
def generate_sha256(file_path: str) -> str:
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
sha256.update(byte_block)
return sha256.hexdigest()
def get_sha256_for_file(file_path: str) -> Optional[str]:
    """Read a stored sha256 value from file_path; return None if the file is missing."""
    try:
        with open(file_path, "r") as f:
            return f.read().strip()
    except FileNotFoundError:
        return None
def get_mnemonic_output(cwd: str, mnemonic: str, targets: list[str]) -> dict[str, dict[str, list[str]]]:
    """Run a bazel aquery for the given mnemonic and return its output artifacts per target."""
union = " + ".join(f'"{target}"' for target in targets)
query = f'mnemonic("{mnemonic}", {union})'
# Create a temporary file for the query
with tempfile.NamedTemporaryFile(
mode="w", suffix=".txt", delete=False
) as query_file:
query_file.write(query)
query_file_path = query_file.name
try:
aquery_cmd = [
"bazel",
"aquery",
"--query_file",
query_file_path,
"--aspects",
scip_const.ASPECT_SCIP_INDEX,
scip_const.ASPECT_OUTPUT_GROUPS,
"--output=jsonproto",
"--keep_going",
]
        action_out_json = utils.output(command=aquery_cmd, cwd=cwd)
        print("Processing action output...")
        return _get_all_outputs(action_out_json)
    except Exception as e:
        print(f"Failed to get '{mnemonic}' outputs: {e}")
        return {}
    finally:
        # The query file was created with delete=False, so remove it explicitly.
        os.remove(query_file_path)
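# Sketch of the structure get_mnemonic_output returns (labels and paths below are
# hypothetical, shown only to illustrate the nesting):
#   {
#       "//foo/bar:baz": {
#           "SomeScipMnemonic": ["bazel-out/k8-fastbuild/bin/foo/bar/baz.scip"],
#       },
#   }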
def _get_all_outputs(json_data):
    """Get all outputs from the aquery json data as {target_label: {mnemonic: [output paths]}}."""
    data = json.loads(json_data)
    path_fragments = data.get("pathFragments", [])
    artifacts = data.get("artifacts", [])
    actions = data.get("actions", [])
    targets = data.get("targets", [])
# Calculate thread pool size as half of available CPUs
max_workers = get_thread_pool_size()
# Create a manager for thread-safe shared objects
manager = multiprocessing.Manager()
# Pre-process path fragments to create a lookup dictionary for parent fragments
# This avoids repeated searches in the process_fragment function
parent_lookup = {}
for fragment in path_fragments:
parent_id = fragment.get("parentId")
if parent_id:
parent_lookup[fragment["id"]] = parent_id
# Create a dictionary to map fragment IDs to their labels for quick lookup
fragment_labels = {fragment["id"]: fragment["label"] for fragment in path_fragments}
# Create a thread-safe dictionary to map pathFragmentId to the full path
path_dict = manager.dict()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
def process_fragment(fragment):
"""Process a single path fragment to build its full path."""
fragment_id = fragment["id"]
# Check if we've already processed this fragment
if fragment_id in path_dict:
return None
# Build the full path by traversing parent IDs
path_parts = []
current_id = fragment_id
# Collect all path parts by traversing up the parent chain
while current_id:
path_parts.append(fragment_labels[current_id])
current_id = parent_lookup.get(current_id)
# Combine path parts in reverse order (from root to leaf)
full_path = "/".join(reversed(path_parts))
return fragment_id, full_path
# Submit all fragments for processing
future_to_fragment = {
executor.submit(process_fragment, fragment): fragment
for fragment in path_fragments
}
# Collect results as they complete
for future in as_completed(future_to_fragment):
result = future.result()
if result: # Skip None results (already processed fragments)
fragment_id, path = result
path_dict[fragment_id] = path
# Create a dictionary to map artifactId to pathFragmentId
artifact_dict = {
artifact["id"]: artifact["pathFragmentId"] for artifact in artifacts
}
# Optimize the target output dictionary creation
# Group actions by target ID to reduce dictionary updates
target_output_dict = {}
for action in actions:
target_id = action["targetId"]
if target_id not in target_output_dict:
target_output_dict[target_id] = {}
if action["mnemonic"] not in target_output_dict[target_id]:
target_output_dict[target_id][action["mnemonic"]] = []
target_output_dict[target_id][action["mnemonic"]].extend(action["outputIds"])
# Create a thread-safe dictionary for the final output
target_output_paths = manager.dict()
# Batch targets for processing to reduce thread overhead
batch_size = max(1, len(targets) // (max_workers * 2))
target_batches = [
targets[i : i + batch_size] for i in range(0, len(targets), batch_size)
]
def process_target_batch(target_batch):
"""Process a batch of targets to build their output paths."""
batch_results = {}
for target in target_batch:
target_id = target["id"]
target_label = target["label"]
mnemonic_to_output_ids = target_output_dict.get(target_id, {})
target_results = {}
for mnemonic, output_ids in mnemonic_to_output_ids.items():
# Use list comprehension with pre-filtering to improve performance
valid_output_ids = [oid for oid in output_ids if oid in artifact_dict]
output_paths = [
path_dict[artifact_dict[oid]] for oid in valid_output_ids
]
if output_paths: # Only add non-empty results
target_results[mnemonic] = output_paths
if target_results: # Only add targets with results
batch_results[target_label] = target_results
return batch_results
# Process target batches in parallel
with ThreadPoolExecutor(max_workers=max_workers) as executor:
batch_results = list(executor.map(process_target_batch, target_batches))
# Merge all batch results into the final dictionary
final_results = {}
for batch_result in batch_results:
for target_label, mnemonics in batch_result.items():
if target_label not in final_results:
final_results[target_label] = {}
for mnemonic, paths in mnemonics.items():
if mnemonic not in final_results[target_label]:
final_results[target_label][mnemonic] = []
final_results[target_label][mnemonic].extend(paths)
return final_results
@lru_cache(maxsize=1)
def get_thread_pool_size() -> int:
"""Cache the thread pool size calculation."""
return max(1, multiprocessing.cpu_count() // 2)
def copy_index(index_to_copy: set[str], dest: str) -> None:
    """Sync SCIP indexes into dest: copy new or changed indexes (with their .sha256
    companions) and prune stale files, keeping the workspace file and JDK indexes."""
    utils.safe_create(dest, is_dir=True)
    def process_and_copy_scip_index(
        source_path: str, current_status: dict
    ) -> Optional[tuple[str, str]]:
try:
# Process index info
source_sha_path = source_path + scip_const.SHA256_FILE_SUFFIX
relative_path = source_path.split(os.path.sep + "bin" + os.path.sep)[-1]
scip_index_name = relative_path.replace("/", "_").replace("-", "_")
new_sha = get_sha256_for_file(source_sha_path)
# Check if copy needed
if (
scip_index_name in current_status
and current_status[scip_index_name] == new_sha
):
return (
scip_index_name,
scip_index_name + scip_const.SHA256_FILE_SUFFIX,
)
# Copy files if needed
dest_index_path = os.path.join(dest, scip_index_name)
shutil.copy(source_path, dest_index_path)
shutil.copy(
source_sha_path, dest_index_path + scip_const.SHA256_FILE_SUFFIX
)
return (scip_index_name, scip_index_name + scip_const.SHA256_FILE_SUFFIX)
except Exception as e:
print(f"Failed to process index {source_path}: {str(e)}")
return None
    def get_current_status(filename: str) -> Optional[tuple[str, str]]:
        """Return (filename, stored sha) for an existing .scip index, or None."""
        if not filename.endswith(".scip"):
            return None
        sha = get_sha256_for_file(
            os.path.join(dest, filename + scip_const.SHA256_FILE_SUFFIX)
        )
        # Return None (not a (None, None) tuple) so filter(None, ...) below drops it.
        return (filename, sha) if sha else None
with ThreadPoolExecutor(max_workers=get_thread_pool_size()) as executor:
# Get current status
current_status = dict(
filter(
None,
executor.map(
get_current_status,
[f for f in os.listdir(dest) if f.endswith(".scip")],
),
)
)
# Process and copy files
copy_results = list(
filter(
None,
executor.map(
lambda src: process_and_copy_scip_index(src, current_status),
index_to_copy,
),
)
)
# Delete old files
files_to_keep = {name for pair in copy_results for name in pair}
files_to_delete = (
set(os.listdir(dest)) - files_to_keep - {scip_const.WORKSPACE_FILE_NAME}
)
files_to_delete = {
f
for f in files_to_delete
if not f.startswith(scip_const.JDK_SCIP_FILE_PREFIX)
}
if files_to_delete:
list(
executor.map(
lambda f: (
os.remove(os.path.join(dest, f))
if os.path.isfile(os.path.join(dest, f))
else shutil.rmtree(os.path.join(dest, f))
),
files_to_delete,
)
)
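# Usage sketch (the source path below is hypothetical): each *.scip index and its
# .sha256 companion are copied into dest under a flattened name, unchanged indexes
# (matching stored sha256) are skipped, and stale files in dest are deleted except
# the workspace file and JDK-prefixed indexes.
#   copy_index({"/repo/bazel-out/k8-fastbuild/bin/foo/bar/baz.scip"}, "/tmp/scip-index")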
def transform_bazel_query_results(qr: list[dict]) -> dict[str, dict[str, list[str]]]:
"""Transform the results of a bazel query into a dictionary of dependencies."""
res = {}
    rules = set()
for target in qr:
if target["type"] == "RULE":
rules.add(target["rule"]["name"])
for target in qr:
if target["type"] != "RULE":
continue
rule = target["rule"]
name = rule["name"]
base_path = rule["name"].split(":")[0][2:]
        # Add all rule inputs to direct_deps, skipping external repositories
        # and inputs that are not rules in this query result.
        direct_deps = []
for dep in rule.get("ruleInput", []):
if dep.startswith("@"):
continue
if dep not in rules:
continue
direct_deps.append(dep)
target_type = rule["ruleClass"]
e_deps = []
deps = []
for attr in rule.get("attribute", []):
if "stringListValue" not in attr:
continue
if attr["name"] == "deps":
deps += [
dep
for dep in attr.get("stringListValue", [])
if not dep.startswith("@") and dep in rules
]
if attr["name"] == "data":
deps += [
data
for data in attr.get("stringListValue", [])
if not data.startswith("@") and data in rules
]
if attr["name"] == "exports":
e_deps = attr.get("stringListValue", [])
info = {
"base_path": base_path,
"deps": list(set(deps)),
"direct_deps": direct_deps,
"exports": e_deps,
"target_type": target_type,
}
res[name] = info
return res
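# Rough shape of the returned mapping (target names and values are hypothetical):
#   {
#       "//foo/bar:baz": {
#           "base_path": "foo/bar",
#           "deps": ["//foo/qux:lib"],
#           "direct_deps": ["//foo/qux:lib"],
#           "exports": [],
#           "target_type": "java_library",
#       },
#   }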
def filter_list_by_regex(list_to_filter: set[str], regex_set: set[str]) -> set[str]:
    """Return the items in list_to_filter that match at least one pattern in regex_set."""
filtered_list = set()
for regex_pattern in regex_set:
for item in list_to_filter:
if re.search(regex_pattern, item):
filtered_list.add(item)
return filtered_list
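# Example with made-up values: filter_list_by_regex({"//foo:a", "//bar:b"}, {r"^//foo"})
# returns {"//foo:a"}. Note that re.search is used, so unanchored patterns match
# anywhere inside an item.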
def dfs(
dep_graph: dict[str, dict[str, list[str]]], target: str, depth: int
) -> list[str]:
"""
Run a depth-first search on a dependency graph. Returns a list of targets.
    Traversal stops at the specified depth; exported deps are followed regardless of depth.
:param dep_graph : A dictionary of dependencies.
:param target: The target to start the search from.
:param depth: The depth to search to.
:return: A list of targets.
"""
result = []
if target not in dep_graph:
return result
for exported_dep in dep_graph[target]["exports"]:
result.append(exported_dep)
result.extend(dfs(dep_graph, exported_dep, depth - 1))
if depth < 0:
return result
result.extend([target])
if depth == 0:
return result
for dep in dep_graph[target]["direct_deps"]:
result.append(dep)
result.extend(dfs(dep_graph, dep, depth - 1))
return result
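# Worked example on a hypothetical graph (only the keys dfs reads are shown):
#   dep_graph = {
#       "//a": {"exports": [], "direct_deps": ["//b"]},
#       "//b": {"exports": ["//c"], "direct_deps": []},
#       "//c": {"exports": [], "direct_deps": []},
#   }
# dfs(dep_graph, "//a", 1) returns ["//a", "//b", "//c", "//b"]: "//b" is reached as a
# direct dep, "//c" via "//b"'s exports, and entries may repeat since visited nodes
# are not tracked.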