tools/scripts/codegen/protocol_tests_gen.py (234 lines of code) (raw):
#!/usr/bin/env python3
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0.
"""
This is a module to handle protocol tests generation.
"""
import json
import os
import pathlib
import re
import sys
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED, ALL_COMPLETED
from codegen.legacy_c2j_cpp_gen import LegacyC2jCppGen
from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel, ModelUtils
# Locations below are relative to the SDK root directory.
PROTOCOL_TESTS_BASE_DIR = "tools/code-generation/protocol-tests"
PROTOCOL_TESTS_CLIENT_MODELS = PROTOCOL_TESTS_BASE_DIR + "/api-descriptions"
PROTOCOL_TESTS_ENDPOINT_RULES = "endpoint-rule-set.json"  # Dummy endpoint ruleset
# Sub-directories of PROTOCOL_TESTS_BASE_DIR holding the test definitions; also used as the test type.
PROTOCOL_TESTS_DEFINITION_SETS = ["input", "output"]
PROTOCOL_TESTS_GENERATED_CLIENTS_DIR = "generated/protocol-tests/test-clients"
PROTOCOL_GENERATED_TESTS_DIR = "generated/protocol-tests/tests"

# Client models (by service name) that are skipped during test client generation.
UNSUPPORTED_CLIENTS = {"rpcv2protocol"  # RPC V2 CBOR support is not implemented on this SDK
                       }
# Test definitions (by file name without the ".json" suffix) that are skipped during tests generation.
UNSUPPORTED_TESTS = {"smithy-rpc-v2-cbor"}

# Regexp to parse a protocol test definition filename ("<name>.json") and extract the test suite name.
# Raw strings with an escaped dot: only a literal ".json" suffix is accepted.
TEST_DEFINITION_FILENAME_PATTERN = re.compile(
    r"^"
    r"(?P<name>.+)"
    r"\.json$"
)
class ProtocolTestsGen(object):
    """A wrapper for Protocol tests generator for C++ SDK

    Generation happens in two stages (see generate()): first the protocol test clients are generated
    from the C2J client models, then one test suite is generated per (test definition, matching client)
    pair. Both stages run their tasks on the supplied ProcessPoolExecutor.
    """

    class ProtoTestC2jClientModelMetadata:
        """Path and "metadata" section of a single C2J protocol test client model."""

        def __init__(self, filename: str, model_path: str, md: dict):
            # NOTE(review): raises AttributeError if filename does not match SERVICE_MODEL_FILENAME_PATTERN.
            self.service_name = SERVICE_MODEL_FILENAME_PATTERN.match(filename).group("service")
            self.model_path = model_path
            self.md = md

    class ProtocolTestModel:
        """A single test generation task: one test definition file paired with one client model."""

        def __init__(self, test_type: str, test_name: str, c2j_test_model: str, c2j_client_md):
            self.test_type = test_type  # ex: input or output
            self.service_name = c2j_client_md.service_name
            self.test_name = test_name
            # File paths to model files
            self.c2j_test_model = c2j_test_model
            self.c2j_client_model = c2j_client_md.model_path

    def __init__(self,
                 args: dict):
        """
        Resolve the protocol tests directories and set up the client and tests generators.

        :param args: generator arguments; the optional "debug" key enables extra logging here,
                     the whole dict is also passed to both LegacyC2jCppGen instances.
        """
        sdk_root_dir = pathlib.Path(__file__).parents[3]

        def resolve_in_sdk(relative_path: str) -> str:
            # Absolute, resolved location of an SDK-root-relative path
            return str(pathlib.Path(f"{sdk_root_dir}/{relative_path}").resolve())

        self.debug = args.get("debug", False)
        self.client_models_dir = resolve_in_sdk(PROTOCOL_TESTS_CLIENT_MODELS)
        self.test_definitions_dir = resolve_in_sdk(PROTOCOL_TESTS_BASE_DIR)
        self.generated_test_clients_dir = resolve_in_sdk(PROTOCOL_TESTS_GENERATED_CLIENTS_DIR)
        self.generated_tests_dir = resolve_in_sdk(PROTOCOL_GENERATED_TESTS_DIR)

        self.c2j_client_generator = LegacyC2jCppGen(args, dict())
        self.c2j_client_generator.path_to_api_definitions = self.client_models_dir
        # The endpoint rules are looked up in the protocol tests base directory
        self.c2j_client_generator.path_to_endpoint_rules = resolve_in_sdk(PROTOCOL_TESTS_BASE_DIR)
        self.c2j_client_generator.output_location = PROTOCOL_TESTS_GENERATED_CLIENTS_DIR

        self.c2j_tests_generator = LegacyC2jCppGen(args, dict())
        self.c2j_tests_generator.path_to_api_definitions = ""
        self.c2j_tests_generator.path_to_endpoint_rules = ""
        self.c2j_tests_generator.output_location = PROTOCOL_GENERATED_TESTS_DIR

    def generate(self, executor: ProcessPoolExecutor, max_workers: int):
        """
        Generate protocol tests (test clients and a corresponding set of tests)

        Tests are generated only if all the test clients were generated successfully.

        :param executor: process pool used to run the generation tasks
        :param max_workers: max number of tasks kept in flight at the same time
        :return: 0 on success, -1 if any generation task failed
        """
        if self._generate_test_clients(executor, max_workers) == 0:
            return self._generate_tests(executor, max_workers)
        return -1

    @staticmethod
    def _wait_for_free_worker(pending: set, done: set, max_workers: int) -> set:
        """
        Block until fewer than max_workers tasks are pending.

        :param pending: futures submitted so far and not yet known to be finished
        :param done: set updated in place with the futures that finished while waiting
        :param max_workers: max number of tasks kept in flight; values below 1 are treated as 1,
                            otherwise waiting on an empty pending set would never return
        :return: the futures that are still pending
        """
        limit = max(1, max_workers)
        while len(pending) >= limit:
            new_done, pending = wait(pending, return_when=FIRST_COMPLETED)
            done.update(new_done)
        return pending

    @staticmethod
    def _report_results(done: set) -> int:
        """
        Inspect finished generation tasks and print a summary of the failed ones.

        :param done: finished futures, each expected to yield a (name, status) tuple
        :return: 0 if every task returned status 0, -1 otherwise
        """
        failures = set()
        for result in done:
            try:
                service, status = result.result()  # will rethrow any exceptions
                if status != 0:
                    raise RuntimeError(f"Protocol test client {service} (re)generation failed: {status}")
            except Exception as exc:
                failures.add(f"Protocol test client (re)generation failed with error.\n Exception: {exc}\n"
                             f"stderr: {getattr(exc, 'stderr', None)}")

        if not failures:
            return 0

        print(f"Code generation failed, processed {len(done)} packages. "
              f"Encountered {len(failures)} failures:\n")
        for failure in failures:
            print(failure)
        return -1

    def _generate_test_clients(self, executor: ProcessPoolExecutor, max_workers: int):
        """
        Build the generator and generate all protocol test clients.

        :param executor: process pool used to run the generation tasks
        :param max_workers: max number of tasks kept in flight at the same time
        :return: 0 on success, -1 if any client generation failed
        """
        self.c2j_client_generator.build_generator(self.c2j_client_generator.path_to_generator)
        service_models = self._collect_test_client_models()
        os.makedirs(self.generated_test_clients_dir, exist_ok=True)

        pending = set()
        done = set()
        sys.stdout.flush()
        for service, model_files in service_models.items():
            pending = self._wait_for_free_worker(pending, done, max_workers)
            task = executor.submit(self._generate_test_client,
                                   service,
                                   model_files,
                                   PROTOCOL_TESTS_GENERATED_CLIENTS_DIR)
            pending.add(task)

        new_done, _ = wait(pending, return_when=ALL_COMPLETED)
        done.update(new_done)
        return self._report_results(done)

    def _generate_test_client(self,
                              service_name: str,
                              model_files: ServiceModel,
                              output_dir: str):
        """
        Generate a single protocol test client (runs inside a worker process).

        :param service_name: name of the test client service
        :param model_files: model files of the test client
        :param output_dir: output directory passed to the client generator
        :return: (service_name, status) as returned by the client generator
        """
        service_name, status = self.c2j_client_generator.generate_client(service_name, model_files, output_dir, None)
        return service_name, status

    def _collect_test_client_models(self) -> dict:
        """
        Collect the client models found in the protocol tests client models directory.

        :return: dict of service name -> ServiceModel, without the UNSUPPORTED_CLIENTS
        """
        service_models = dict()
        for filename in os.listdir(self.client_models_dir):
            if not os.path.isfile("/".join([str(self.client_models_dir), filename])):
                continue
            # NOTE(review): a file name not matching the pattern raises AttributeError below.
            match = SERVICE_MODEL_FILENAME_PATTERN.match(filename)
            service_model_name = match.group("service")
            _ = match.group("date")  # the file name must carry a date group, its value is not used
            if service_model_name in UNSUPPORTED_CLIENTS:
                print(f"Skipping protocol tests client generation: {filename}")
                continue
            use_smithy = ModelUtils.is_smithy_enabled(service_model_name, self.client_models_dir, filename)
            service_models[service_model_name] = ServiceModel(service_model_name, filename,
                                                              PROTOCOL_TESTS_ENDPOINT_RULES, None, use_smithy)
        return service_models

    def _get_client_models_metadata(self) -> list:
        """
        Load the "metadata" section of every client model in the client models directory.

        Files that cannot be processed are reported on stdout and skipped.

        :return: list of ProtoTestC2jClientModelMetadata, ordered by model file name
        """
        models = list()
        for filename in sorted(os.listdir(self.client_models_dir)):
            if not os.path.isfile("/".join([self.client_models_dir, filename])):
                continue
            model_abspath = str(pathlib.Path(f"{self.client_models_dir}/{filename}").resolve())
            # Explicit encoding: do not depend on the platform default when reading JSON
            with open(model_abspath, 'r', encoding="utf-8") as file_content:
                try:
                    c2j_model = json.load(file_content)
                    model_metadata = self.ProtoTestC2jClientModelMetadata(filename, model_abspath,
                                                                          c2j_model.get("metadata"))
                    models.append(model_metadata)
                except Exception as exc:
                    print(f"ERROR: unexpected file content in protocol tests clients dir {self.client_models_dir}. "
                          f"Expected c2j client model, but json metadata kew is missing: {exc}")
        return models

    @staticmethod
    def _find_matching_test_clients(test_clients_md: list, test_path: str) -> list:
        """
        Get c2j client models matching the test suite.

        A client matches when its metadata agrees with the metadata of the first test case on
        "apiVersion", "protocols", "jsonVersion" and "targetPrefix" (a missing field equals None).
        More than 1 is possible (ex: xml and xml with namespace clients for a single test suite).

        :param test_clients_md: list of ProtoTestC2jClientModelMetadata to choose from
        :param test_path: path to the protocol test definition file
        :return: list of matching client metadata; any error is printed and the matches
                 found so far (possibly none) are returned
        """
        fields_to_match = ("apiVersion", "protocols", "jsonVersion", "targetPrefix")
        result = list()
        with open(test_path, 'r', encoding="utf-8") as file_content:
            try:
                proto_test_model = json.load(file_content)
                proto_test_md = proto_test_model[0].get("metadata")
                for c2j_md in test_clients_md:
                    if all(proto_test_md.get(field, None) == c2j_md.md.get(field, None)
                           for field in fields_to_match):
                        result.append(c2j_md)
            except Exception as exc:
                print(f"ERROR: unexpected file content in protocol tests {test_path}. "
                      f"Expected c2j protocol test, but json metadata kew is missing: {exc}")
        return result

    def _collect_test_definition_models(self) -> dict:
        """
        Pair every supported test definition with its matching client model(s).

        When several clients match one test definition, the first keeps the definition name
        as its key and the following ones get a "-<index>" suffix.

        :return: dict of test type -> {test name -> ProtocolTestModel},
                 ex: "{input: {ec2: ProtocolTestModel}, output: {ec2: ProtocolTestModel}}"
        :raises Exception: if no client model matches a test definition
        """
        all_test_clients_md = self._get_client_models_metadata()

        test_models = dict()
        for test_def_group in PROTOCOL_TESTS_DEFINITION_SETS:
            group_dir = f"{self.test_definitions_dir}/{test_def_group}"
            for filename in os.listdir(group_dir):
                if not os.path.isfile(f"{group_dir}/{filename}"):
                    continue
                match = TEST_DEFINITION_FILENAME_PATTERN.match(filename)
                if match is None:
                    # Not a test definition: skip it instead of failing on match.group below
                    print(f"Skipping unexpected file in protocol tests definitions: {test_def_group}/{filename}")
                    continue
                test_def_name = match.group("name")
                if test_def_name in UNSUPPORTED_TESTS:
                    print(f"Skipping protocol tests generation: {test_def_group}/{filename}")
                    continue
                test_def_path = str(pathlib.Path(f"{group_dir}/{filename}").resolve())

                test_clients_for_suite = self._find_matching_test_clients(all_test_clients_md, test_def_path)
                if not test_clients_for_suite:
                    raise Exception(f"ERROR: Unable to find C2J client model for the test suite: {test_def_path}")

                for index, client_md in enumerate(test_clients_for_suite):
                    test_def_key = test_def_name if index == 0 else f"{test_def_name}-{index}"
                    assert test_models.get(test_def_group, dict()).get(test_def_key, None) is None, \
                        f"This test suite {test_def_group}/{test_def_key} already exists: {test_models}"
                    if self.debug:
                        # basename also handles resolved paths that do not use "/" as a separator
                        print("Protocol test generation task:\t"
                              f"{os.path.basename(test_def_path)} with {os.path.basename(client_md.model_path)}")
                    group_models = test_models.setdefault(test_def_group, dict())
                    group_models[test_def_key] = self.ProtocolTestModel(test_type=test_def_group,
                                                                        test_name=test_def_key,
                                                                        c2j_test_model=test_def_path,
                                                                        c2j_client_md=client_md)
        return test_models

    def _generate_tests(self, executor: ProcessPoolExecutor, max_workers: int):
        """
        Generate all protocol test suites.

        :param executor: process pool used to run the generation tasks
        :param max_workers: max number of tasks kept in flight at the same time
        :return: 0 on success, -1 if any test suite generation failed
        """
        test_models = self._collect_test_definition_models()

        pending = set()
        done = set()
        sys.stdout.flush()
        for test_def_group, test_suites in test_models.items():
            # Create the per-type output directory under the generated tests directory
            # (a bare relative name would be created in the current working directory instead).
            os.makedirs(f"{self.generated_tests_dir}/{test_def_group}", exist_ok=True)
            for test_model in test_suites.values():
                pending = self._wait_for_free_worker(pending, done, max_workers)
                task = executor.submit(self._generate_single_protocol_test,
                                       test_model)
                pending.add(task)

        new_done, _ = wait(pending, return_when=ALL_COMPLETED)
        done.update(new_done)
        return self._report_results(done)

    def _generate_single_protocol_test(self, models: ProtocolTestModel):
        """Call java generator to generate a single protocol test suite

        The generator output is extracted into "<generated tests dir>/<test type>"; the directory
        "<generated tests dir>/<test type>/<test name>" is passed to extract_zip as the one to delete.

        :param models: ProtocolTestModel describing the test type (ex: "input" or "output"),
                       the test name (ex: "ec2", "json", "xml") and the model files to use
        :return: (name used for logging, status) as returned by extract_zip
        """
        generator_jar = self.c2j_tests_generator.path_to_generator + "/" + self.c2j_tests_generator.GENERATOR_JAR
        run_command = [
            "java",
            "-jar", generator_jar,
            "--inputfile", models.c2j_client_model,
            "--protocol-tests", models.c2j_test_model,
            "--protocol-tests-type", models.test_type,
            "--protocol-tests-name", models.test_name,
            "--service", models.service_name,
            "--outputfile", "STDOUT",
            "--generate-tests",
        ]

        name_for_logging = f"protocol test {models.test_type}/{models.test_name}"
        output_zip_file = self.c2j_tests_generator.run_generator_once(name_for_logging,
                                                                      run_command, "STDOUT")

        dir_to_delete = f"{self.generated_tests_dir}/{models.test_type}/{models.test_name}"
        dir_to_extract = f"{self.generated_tests_dir}/{models.test_type}"
        name_for_logging, status = self.c2j_tests_generator.extract_zip(output_zip_file, name_for_logging,
                                                                        dir_to_extract, dir_to_delete)
        return name_for_logging, status