scripts/pyre_check.py (41 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run ``pyre check`` and fail only on type errors not covered by IGNORE_PATTERNS.

Each line pyre writes to stdout is treated as one error report. Reports that
contain any of the substrings in ``IGNORE_PATTERNS`` are considered known,
non-critical noise; every other non-blank line is printed and counted. The
script exits non-zero if any critical error remains or if pyre itself fails.
"""

import io
import subprocess
import sys

# Substrings identifying pyre error reports that are tolerated. Matching is a
# plain substring test against each stdout line.
IGNORE_PATTERNS = [
    "Missing attribute annotation [4]",
    "Missing global annotation [5]",
    # expected `List[str]` but got `List[_ScalarV]`
    # `google.protobuf.internal.containers._ScalarV` has no attribute `__add__`
    "_ScalarV",
    # Invalid type parameters [24]: Generic type `_Mapping` expects 2 type parameters.
    "_pb2.pyi:",
    "Annotation `pa.Array` is not defined as a type",
    "Annotation `pa.Field` is not defined as a type",
    "Annotation `pa.DataType` is not defined as a type",
    "Annotation `pa.RecordBatch` is not defined as a type",
    "Annotation `pa.Scalar` is not defined as a type",
    "Annotation `parquet.FileMetaData` is not defined as a type",
    "Annotation `faiss.Index` is not defined as a type.",
    "Annotation `struct_pb2.Struct` is not defined as a type.",
    "Undefined attribute [16]: Module `pyarrow` has no attribute",
    "Undefined attribute [16]: Module `pyarrow.compute` has no attribute",
    "Undefined attribute [16]: Module `pyarrow.csv` has no attribute",
    # type-safety of torch.nn.Module instances
    # https://github.com/pytorch/pytorch/issues/81462
    # Call error [29]: `typing.Union[nn.modules.module.Module, torch._tensor.Tensor]` is
    # not a function.
    "Union[nn.modules.module.Module, torch._tensor.Tensor]",
    "Union[torch._tensor.Tensor, torch.nn.modules.module.Module]",
    "Union[torch._tensor.Tensor, nn.modules.module.Module]",
    "Union[Module, Tensor]",
]

# Pyre exit codes that mean "the check ran": 0 = no errors, 1 = errors found.
# Any other code means pyre itself failed (bad config, crash, ...), and its
# stdout cannot be trusted as a complete error list.
_PYRE_OK_EXIT_CODES = (0, 1)


def is_ignored(line, patterns=None):
    """Return True if ``line`` contains any of the ignore ``patterns``.

    Args:
        line: One pyre error report line.
        patterns: Iterable of substrings to ignore; defaults to IGNORE_PATTERNS.
    """
    if patterns is None:
        patterns = IGNORE_PATTERNS
    return any(pattern in line for pattern in patterns)


def critical_errors(lines, patterns=None):
    """Return the lines that represent critical (non-ignored) errors.

    Trailing whitespace/newlines are stripped from each returned line, and
    blank lines are skipped so they are never counted as errors.

    Args:
        lines: Iterable of pyre stdout lines (with or without newlines).
        patterns: Iterable of substrings to ignore; defaults to IGNORE_PATTERNS.

    Returns:
        A list of the remaining error lines, in input order.
    """
    stripped = (line.rstrip() for line in lines)
    return [text for text in stripped if text and not is_ignored(text, patterns)]


def main():
    """Run ``pyre check``, report critical errors, and return an exit code.

    Returns:
        0 when pyre ran and found no critical errors; 1 when critical errors
        were found; pyre's own exit code when pyre failed to run the check.
    """
    # stderr is not captured so pyre's progress/diagnostic output still
    # reaches the terminal, which matters when pyre itself fails.
    result = subprocess.run(["pyre", "check"], stdout=subprocess.PIPE)
    if result.returncode not in _PYRE_OK_EXIT_CODES:
        print(f"pyre check failed with exit code {result.returncode}.")
        return result.returncode

    errors = critical_errors(io.StringIO(result.stdout.decode("utf-8")))
    for error in errors:
        print(error)
    if errors:
        print(f"Found {len(errors)} critical type errors.")
        return 1
    print("Found no critical type errors.")
    return 0


if __name__ == "__main__":
    sys.exit(main())