scripts/merge_waf_checklists.py (152 lines of code) (raw):

#################################################################################
#
# This script tries to merge different WAF checklists (WAF, review, svc guides).
# It tries to find duplicates by calculating the cosine distance between the
# sentence embeddings of the recommendation texts.
#
# Last updated: June 2024
#
#################################################################################

import argparse
import json
import sys

# Cosine distance (1 - cosine similarity) under which two recos are reported as
# likely duplicates.
MAX_DUPLICATE_DISTANCE = 0.05

# Lower-cased WAF pillar names and synonyms, mapped to the standard pillar name.
WAF_PILLAR_NAMES = {
    'reliability': 'Reliability',
    'resiliency': 'Reliability',
    'cost': 'Cost',
    'cost optimization': 'Cost',
    'cost efficiency': 'Cost',
    'performance': 'Performance',
    'scalability': 'Performance',
    'operations': 'Operations',
    'operational excellence': 'Operations',
    'security': 'Security',
}

# Parsed command-line arguments. main() replaces this with the real arguments;
# the default lets the helper functions (which read args.verbose) be used
# without running the script.
args = argparse.Namespace(verbose=False)


def parse_arguments(argv=None):
    """Parse the command-line arguments (sys.argv[1:] when argv is None)."""
    parser = argparse.ArgumentParser(description='Merge different WAF checklists and removes duplicates')
    parser.add_argument('--review-checklist-file', dest='review_checklist_file', action='store',
                        help='You need to supply the name of the JSON file with the review checklist to be merged (default: None)')
    parser.add_argument('--aprl-checklist-file', dest='aprl_checklist_file', action='store',
                        help='You need to supply the name of the JSON file with the APRL checklist to be merged (default: None)')
    parser.add_argument('--sg-checklist-file', dest='sg_checklist_file', action='store',
                        help='You need to supply the name of the JSON file with the Service Guide checklist to be merged (default: None)')
    parser.add_argument('--output-file', dest='output_file', action='store',
                        help='The resulting checklist will be stored here (default: None)')
    parser.add_argument('--service-dictionary', dest='service_dictionary', action='store',
                        help='JSON file with dictionary to map services to standard names and to ARM services')
    parser.add_argument('--calculate-embeddings', dest='calculate_embeddings', action='store_true', default=False,
                        help='Whether embeddings and reco mappings will be calculated (default: False)')
    parser.add_argument('--save-embeddings', dest='save_embeddings', action='store_true', default=False,
                        help='Whether calculated embeddings will be stored in the provided files (default: False)')
    parser.add_argument('--max-recos', dest='max_recos', action='store', type=int, default=0,
                        help='You can optionally define a maximum of recos to process for embeddings. If 0 (default), no limit is set.')
    parser.add_argument('--verbose', dest='verbose', action='store_true', default=False,
                        help='Run in verbose mode (default: False)')
    return parser.parse_args(argv)


def load_json_file(filename):
    """Load a checklist from a UTF-8 JSON file and return it.

    Exits the process with status 1 if the file cannot be read/parsed or if the
    checklist has no "items" element.
    """
    try:
        # The files are written as UTF-8 (see dump_json_file), so read them as UTF-8
        # instead of relying on the platform's default encoding.
        with open(filename, encoding='utf-8') as f:
            checklist = json.load(f)
    except (OSError, TypeError, ValueError) as e:
        # TypeError covers a missing (None) file name, ValueError covers JSON/decoding errors.
        print("ERROR: Error when loading checklist JSON file", filename, "-", str(e))
        sys.exit(1)
    if not isinstance(checklist, dict) or 'items' not in checklist:
        print('ERROR: checklist in file {0} does not have an "items" element.'.format(filename))
        sys.exit(1)
    if args.verbose:
        print('DEBUG: {0} recos loaded from {1}'.format(len(checklist['items']), filename))
    return checklist


def dump_json_file(json_object, filename):
    """Write json_object to filename as pretty-printed, key-sorted UTF-8 JSON."""
    if args.verbose:
        print("DEBUG: Dumping JSON object to file", filename)
    json_string = json.dumps(json_object, sort_keys=True, ensure_ascii=False, indent=4, separators=(',', ': '))
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(json_string)


def calculate_embeddings(checklist, model):
    """Add an "embeddings" list to every reco that has a "text" but no embeddings yet.

    model is expected to expose a SentenceTransformer-like encode() that takes a
    list of texts and returns one vector (with .tolist()) per text.
    Returns the same (mutated) checklist.
    See https://stackoverflow.com/questions/65199011/is-there-a-way-to-check-similarity-between-two-full-sentences-in-python
    """
    items = checklist['items']
    if args.verbose:
        print('DEBUG: Calculating embeddings for checklist ({0} recos)...'.format(len(items)))
    pending = []
    for reco in items:
        if 'embeddings' in reco:
            continue
        if 'text' in reco:
            pending.append(reco)
        elif args.verbose:
            print('DEBUG: Missing "text" tag in recommendation')
    if pending:
        # Encode all texts in one batched call instead of one model call per reco.
        vectors = model.encode([reco['text'] for reco in pending], show_progress_bar=args.verbose)
        for reco, vector in zip(pending, vectors):
            # Store a plain list: numpy arrays are not JSON serializable.
            reco['embeddings'] = vector.tolist()
    return checklist


def verify_checklist(checklist):
    """Print (in verbose mode) how many items have "text" and "embeddings" keys."""
    items = checklist['items']
    items_with_text_count = sum(1 for x in items if 'text' in x)
    items_with_embeddings_count = sum(1 for x in items if 'embeddings' in x)
    if args.verbose:
        print('DEBUG: checklist analysis: {0} elements in total, {1} elements with "text" key, {2} elements with "embeddings" key.'.format(len(items), items_with_text_count, items_with_embeddings_count))


def get_standard_service_name(service_name, service_dictionary=None):
    """Return the standard service name for service_name.

    service_dictionary is a list of entries like {"service": ..., "names": [...]};
    the first entry whose "names" contains service_name wins. Unknown names (or a
    missing dictionary) are returned unchanged.
    """
    for svc in service_dictionary or []:
        if service_name in svc.get('names', []):
            return svc['service']
    return service_name


def get_standard_waf_pillar_name(waf_pillar_name):
    """Return the standard WAF pillar name; unknown names are returned in Title case."""
    if not isinstance(waf_pillar_name, str):
        return waf_pillar_name
    cleaned_name = waf_pillar_name.strip()
    return WAF_PILLAR_NAMES.get(cleaned_name.lower(), cleaned_name.title())


def select_closest_matches(source_recos, target_recos, similarities, max_distance=MAX_DUPLICATE_DISTANCE):
    """Pick, for each source reco, the most similar target reco.

    similarities[i][j] is the cosine similarity between source_recos[i] and
    target_recos[j] (higher means more similar). Returns a list of
    (source_reco, target_reco, distance) tuples for the pairs whose cosine
    distance (1 - similarity) is strictly below max_distance.
    """
    matches = []
    for source_reco, row in zip(source_recos, similarities):
        if not row:
            continue
        best_index = max(range(len(row)), key=row.__getitem__)
        distance = 1.0 - row[best_index]
        if distance < max_distance:
            matches.append((source_reco, target_recos[best_index], distance))
    return matches


def find_closest_recos(source_recos, target_recos, max_distance=MAX_DUPLICATE_DISTANCE):
    """Match source recos to their closest target reco using the stored embeddings.

    Recos without an "embeddings" key are ignored. Returns the same tuples as
    select_closest_matches().
    """
    sources = [reco for reco in source_recos if 'embeddings' in reco]
    targets = [reco for reco in target_recos if 'embeddings' in reco]
    if not sources or not targets:
        return []
    # Imported lazily: needed whenever embeddings exist, including embeddings
    # loaded from previously saved files (no --calculate-embeddings).
    from sentence_transformers import util
    # One similarity matrix (len(sources) x len(targets)) instead of one call per pair.
    similarities = util.pytorch_cos_sim([reco['embeddings'] for reco in sources],
                                        [reco['embeddings'] for reco in targets]).tolist()
    return select_closest_matches(sources, targets, similarities, max_distance)


def report_duplicates(sg_checklist, review_checklist, max_recos=0):
    """Print the review reco closest to each Service Guide reco when it is a likely duplicate.

    Only the first max_recos Service Guide recos are considered when max_recos > 0.
    """
    sg_recos = sg_checklist['items']
    if 0 < max_recos < len(sg_recos):
        if args.verbose:
            print('DEBUG: maximum number of recos provided ({0}) reached'.format(max_recos))
        sg_recos = sg_recos[:max_recos]
    if not any('embeddings' in reco for reco in sg_recos):
        if args.verbose:
            print('DEBUG: no embeddings available, skipping duplicate detection')
        return
    sg_missing = sum(1 for reco in sg_recos if 'embeddings' not in reco)
    review_missing = sum(1 for reco in review_checklist['items'] if 'embeddings' not in reco)
    if sg_missing:
        print('ERROR: Embeddings missing from {0} SG recos'.format(sg_missing))
    if review_missing:
        print('ERROR: Embeddings missing from {0} review recos'.format(review_missing))
    # Matches are only reported in verbose mode, so skip the expensive computation otherwise.
    if not args.verbose:
        return
    for sg_reco, review_reco, distance in find_closest_recos(sg_recos, review_checklist['items']):
        print('DEBUG: Match with distance {0}'.format(distance))
        print('DEBUG: SG reco : {0}'.format(sg_reco.get('text')))
        print('DEBUG: Review reco: {0}'.format(review_reco.get('text')))


def merge_checklists(*checklists):
    """Return a new checklist with the metadata of the first checklist and the items of all.

    The input checklists are not modified (the item dicts themselves are shared).
    """
    merged = dict(checklists[0])
    merged['items'] = [item for checklist in checklists for item in checklist['items']]
    return merged


def load_service_dictionary(filename):
    """Load the service dictionary JSON list; return None (after printing an error) on failure."""
    if args.verbose:
        print("DEBUG: Loading service dictionary from", filename)
    try:
        with open(filename, encoding='utf-8') as f:
            service_dictionary = json.load(f)
    except (OSError, ValueError) as e:
        print("ERROR: Error when loading service dictionary from", filename, "-", str(e))
        return None
    if args.verbose:
        print("DEBUG: service dictionary loaded successfully with {0} elements".format(len(service_dictionary)))
    return service_dictionary


def main():
    """Load the three checklists, optionally embed them, report duplicates, merge and save."""
    global args
    args = parse_arguments()

    # Load the checklists
    review_checklist = load_json_file(args.review_checklist_file)
    aprl_checklist = load_json_file(args.aprl_checklist_file)
    sg_checklist = load_json_file(args.sg_checklist_file)
    checklists = (review_checklist, aprl_checklist, sg_checklist)

    # Calculate the embeddings for each reco
    if args.calculate_embeddings:
        # Only import the (heavy) module if we are going to use it
        from sentence_transformers import SentenceTransformer
        # "all-MiniLM-L6-v2" is a lighter alternative model.
        model = SentenceTransformer('distilbert-base-nli-mean-tokens')
        for checklist in checklists:
            calculate_embeddings(checklist, model)

    # Verify that we have all we need
    for checklist in checklists:
        verify_checklist(checklist)

    # Save the embeddings so that we don't calculate them again. This happens before
    # merging/standardizing, so the source files only gain the "embeddings" keys.
    if args.save_embeddings:
        dump_json_file(review_checklist, args.review_checklist_file)
        dump_json_file(aprl_checklist, args.aprl_checklist_file)
        dump_json_file(sg_checklist, args.sg_checklist_file)

    # For every reco of the WAF service guide checklist, find the closest review reco.
    # The service/pillar might not match ('Azure Kubernetes Service' vs 'AKS',
    # 'Reliability' vs 'Resiliency', etc), so all review recos are compared.
    report_duplicates(sg_checklist, review_checklist, args.max_recos)

    # Merge all three checklists in a new one
    full_checklist = merge_checklists(review_checklist, aprl_checklist, sg_checklist)

    # Standardize the service names with the service dictionary (if one provided)
    if args.service_dictionary:
        service_dictionary = load_service_dictionary(args.service_dictionary)
        if service_dictionary:
            for item in full_checklist['items']:
                if 'service' in item:
                    item['service'] = get_standard_service_name(item['service'], service_dictionary=service_dictionary)

    # Standardize the WAF pillar names
    for item in full_checklist['items']:
        if 'waf' in item:
            item['waf'] = get_standard_waf_pillar_name(item['waf'])

    # If an output file was specified, save the resulting checklist
    if args.output_file:
        dump_json_file(full_checklist, args.output_file)
        if args.verbose:
            print('DEBUG: Merged checklist with {0} elements saved to {1}'.format(len(full_checklist['items']), args.output_file))


if __name__ == '__main__':
    main()