scripts/sync_nullability_annotations.py (97 lines of code) (raw):
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import pathlib
import re
import subprocess
import sys
from typing import List, Optional
def find_m_file_for_h_file(h_file_path, all_m_files):
    """Return the implementation (.m) file that pairs with a header (.h) file.

    The sibling ``<stem>.m`` next to the header wins when it exists on disk.
    Otherwise ``all_m_files`` is searched for entries whose file name is
    exactly ``<stem>.m``.

    Args:
        h_file_path: path of the header file.
        all_m_files: candidate ``.m`` paths searched when no sibling exists.

    Returns:
        The matching ``.m`` path, or None when nothing matches or when more
        than one candidate matches (the ambiguity is printed).
    """
    # Swap only the trailing extension; a plain str.replace would also rewrite
    # ".h" occurring elsewhere in the path (e.g. "Foo.helper.h").
    stem, _ = os.path.splitext(h_file_path)
    m_file_path = stem + ".m"

    # Check in the same folder
    if os.path.exists(m_file_path):
        return m_file_path

    # Check in the whole project. Compare whole file names: a suffix test
    # would wrongly treat "MyFoo.m" as a match for "Foo.m".
    m_file_basename = os.path.basename(m_file_path)
    matches = [f for f in all_m_files if os.path.basename(f) == m_file_basename]

    if len(matches) == 1:
        return matches[0]

    # Finding 0 matches is fine since many header files have no corresponding
    # .m file (e.g. TestTools-Bridging-Header.h), but more than one is suspicious.
    if len(matches) > 1:
        print(f"More than one match found for {m_file_basename}. Matches: {matches}")
    return None
def _header_method_map(h_file_text):
    """Map normalized header method declarations to their original text.

    Only declarations whose key differs from the declaration itself (i.e. they
    contain nullability keywords or span several lines) are recorded.
    """
    # key: method declaration with newlines and nullable removed
    # value: the original declaration from the header
    method_map = {}
    for match in re.finditer(
        r"^([+-].*?);$", h_file_text, flags=re.MULTILINE | re.DOTALL
    ):
        # Drop a trailing Swift-naming macro line (possibly indented) so the
        # declaration can match the .m signature, which never carries it.
        method_declaration = re.sub(
            r"\n[ \t]*(NS_SWIFT_NAME|NS_SWIFT_UNAVAILABLE).*",
            "",
            match.group(1),
        )
        key_for_method = key_for_method_declaration(method_declaration)
        if method_declaration != key_for_method:
            method_map[key_for_method] = method_declaration
    return method_map


def _apply_header_declarations(m_file_text, method_map):
    """Return m_file_text with matching method signatures replaced by the header's.

    A signature is the text from a line starting with "+" or "-" up to a line
    containing only "{". Each matched signature is rewritten in place, so
    identical text elsewhere in the file (a longer overload sharing the prefix,
    a comment) is left untouched. The signature may not contain "{", "}" or
    ";", so a method whose brace sits on the same line cannot swallow the
    following method's signature.
    """
    if not method_map:
        return m_file_text

    def _replace(match):
        header_declaration = method_map.get(key_for_method_declaration(match.group(1)))
        if header_declaration is None:
            return match.group(0)
        return header_declaration + "\n{"

    return re.sub(r"^([+-][^{};]*?)\n\{$", _replace, m_file_text, flags=re.MULTILINE)


def main():
    """Copy nullability annotations from header declarations into .m files.

    For every tracked .h file with a matching .m file, each .m method signature
    that equals a header declaration modulo nullability keywords and line
    breaks is replaced by the header's declaration. Changed .m files are
    rewritten on disk.
    """
    base_dir = determine_base_dir()
    os.chdir(base_dir)

    all_h_files = get_files_with_extension("*.h")
    all_m_files = get_files_with_extension("*.m")
    print(f"Found {len(all_h_files)} '*.h' files and {len(all_m_files)} '*.m' files")

    for h_file_path in all_h_files:
        m_file_path = find_m_file_for_h_file(h_file_path, all_m_files)
        if not m_file_path:
            # Many .h files won't have a corresponding .m file, example bridging header
            continue

        # Remove the pairing so a later header cannot claim the same .m file.
        if m_file_path in all_m_files:
            all_m_files.remove(m_file_path)
        else:
            print(f"Path exists: {m_file_path} but was not found in 'all_m_files'")

        h_file_text = read_text_from_file(h_file_path)
        m_file_text = read_text_from_file(m_file_path)

        method_map = _header_method_map(h_file_text)
        updated_m_file_text = _apply_header_declarations(m_file_text, method_map)
        if updated_m_file_text != m_file_text:
            write_text_to_file(updated_m_file_text, m_file_path)

    # At this point all_m_files holds only the .m files no header was paired
    # with; print it (e.g. excluding "*Tests.m") when debugging the pairing.
def read_text_from_file(file: str) -> str:
    """Return the full contents of ``file`` decoded as UTF-8.

    The encoding is explicit because open()'s default depends on the locale.
    """
    with open(file, "r", encoding="utf-8") as f:
        return f.read()
def write_text_to_file(text: str, file: str) -> None:
    """Overwrite ``file`` with ``text`` encoded as UTF-8.

    The encoding is explicit so output does not depend on the locale.
    """
    with open(file, "w", encoding="utf-8") as f:
        f.write(text)
def key_for_method_declaration(method_declaration: str) -> str:
    """Normalize a method declaration into a lookup key.

    Strips the ``nullable `` keyword, collapses `` _Nullable `` to a single
    space, then joins lines by turning each newline (and the whitespace around
    it) into one space.
    """
    normalized = method_declaration
    # Remove nullability keywords, in this order.
    for needle, substitute in (("nullable ", ""), (" _Nullable ", " ")):
        normalized = normalized.replace(needle, substitute)
    # Remove newlines
    return re.sub(r"\s*\n\s*", " ", normalized)
def determine_base_dir() -> str:
    """Return the repository root directory.

    Uses ``git rev-parse --show-toplevel``; when that fails or prints nothing,
    falls back to the parent of the directory containing this script.
    """
    git_root = get_output("git rev-parse --show-toplevel")
    if git_root:
        return git_root
    script_dir = os.path.dirname(os.path.realpath(__file__))
    return str(pathlib.Path(script_dir).parent.absolute())
def get_files_with_extension(extension: str) -> List[str]:
    """List tracked files matching a glob such as ``"*.h"``.

    Tries ``git ls-files`` first and falls back to ``hg files``. Paths that
    start with "samples" or "testing" are dropped. Exits the process with
    status 1 when neither command yields any output.
    """
    listing = get_output(f"git ls-files '{extension}'") or get_output(
        f"hg files -I '**/{extension}'"
    )
    if not listing:
        print(f"No files found with extension: {extension}", file=sys.stderr)
        sys.exit(1)

    excluded_prefixes = ("samples", "testing")
    return [path for path in listing.splitlines() if not path.startswith(excluded_prefixes)]
def write_lines_to_file(lines: List[str], file: str) -> None:
    """Overwrite ``file`` with ``lines`` (written as-is, no separators added).

    Uses UTF-8 explicitly, matching the other file helpers, so output does not
    depend on the locale.
    """
    with open(file, "w", encoding="utf-8") as f:
        f.writelines(lines)
def get_output(command: str) -> Optional[str]:
    """Run ``command`` through the shell and return its stdout.

    Trailing whitespace is stripped from the decoded output. Returns None when
    the command exits with a non-zero status; an empty string is returned for
    a successful command that prints nothing.
    """
    result = subprocess.run(command, shell=True, check=False, capture_output=True)
    if result.returncode == 0:
        return result.stdout.decode().rstrip()
    # Uncomment for debugging:
    # print(f"Failed: {command}\nSTDERR: {result.stderr.decode()}")
    return None
# Run the sync only when executed as a script, not when imported.
if __name__ == "__main__":
    main()