# parquet_cli/audit_tool/audit_tool.py
########################################################################################################################
# Description: This script audits the parquet files in S3 against records in OpenSearch. It prints out the S3 keys #
# of the files that are not found in OpenSearch. #
# #
# Usage: #
# python audit_tool.py [ -f <plain|mock-s3> ] [ -o <output> ] [ --from-time <ISO datetime> ] #
# #
# Environment variables: #
# AWS_ACCESS_KEY_ID: AWS access key ID #
# AWS_SECRET_ACCESS_KEY: AWS secret access key #
# OPENSEARCH_ENDPOINT: OpenSearch endpoint #
# OPENSEARCH_PORT: OpenSearch port (default: 443) #
# OPENSEARCH_ID_PREFIX: OpenSearch ID prefix, e.g. s3://cdms-dev-in-situ-parquet #
# OPENSEARCH_INDEX: OpenSearch index, e.g. [ parquet_stats_alias | entry_file_records_alias ] #
# OPENSEARCH_PATH_PREFIX: OpenSearch path prefix (use '' if no prefix needed) #
# OPENSEARCH_BUCKET: Name of the S3 bucket storing the ingested parquet files #
########################################################################################################################
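# Example invocation (hypothetical values):
#   export OPENSEARCH_ENDPOINT=search-example.us-west-2.es.amazonaws.com
#   export OPENSEARCH_BUCKET=cdms-dev-in-situ-parquet
#   export OPENSEARCH_INDEX=parquet_stats_alias
#   export OPENSEARCH_PATH_PREFIX=insitu/parquet
#   python audit_tool.py -o missing_keys.txt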
import boto3
from requests_aws4auth import AWS4Auth
from elasticsearch import Elasticsearch, RequestsHttpConnection, NotFoundError
import os
import sys
import argparse
import textwrap
import json
import logging
from datetime import timedelta, datetime, timezone
from parquet_flask.aws import AwsSQS, AwsSNS
from parquet_flask.cdms_lambda_func.lambda_logger_generator import LambdaLoggerGenerator
LambdaLoggerGenerator.remove_default_handlers()
logger = LambdaLoggerGenerator.get_logger(
__name__,
log_format='%(asctime)s [%(levelname)s] [%(name)s::%(lineno)d] %(message)s'
)
LambdaLoggerGenerator.get_logger('elasticsearch', log_level=logging.WARNING)
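# Ordered phases of an audit run; the persisted state records the phase name so a re-invocation knows where to resume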
PHASES = dict(
start=0,
list=1,
audit=2
)
# Append a slash to the end of a string if it doesn't already have one
def append_slash(string: str):
if string is None:
return None
elif string == '':
return string
elif string[-1] != '/':
return string + '/'
else:
return string
def key_to_sqs_msg(key: str, bucket: str):
s3_event = {
'Records': [
{
'eventName': 'ObjectCreated:Put',
's3': {
'bucket': {'name': bucket},
'object': {'key': key}
}
}
]
}
sqs_body = json.dumps(s3_event)
return sqs_body
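# Example (hypothetical bucket/key): key_to_sqs_msg('prefix/file.parquet', 'my-bucket') returns
# '{"Records": [{"eventName": "ObjectCreated:Put", "s3": {"bucket": {"name": "my-bucket"}, "object": {"key": "prefix/file.parquet"}}}]}'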
def reinvoke(state, bucket, s3_client, lambda_client, function_name):
    # Serialize any datetime fields so the state dict is JSON-safe before persisting
    for time_key in ('lastListTime', 'listStartTime'):
        if time_key in state and isinstance(state[time_key], datetime):
            state[time_key] = state[time_key].strftime("%Y-%m-%dT%H:%M:%S%z")
    logger.info('Preparing to reinvoke. Persisting audit state to S3')
object_data = json.dumps(state).encode('utf-8')
s3_client.put_object(Bucket=bucket, Key='AUDIT_STATE.json', Body=object_data)
response = lambda_client.invoke(
FunctionName=function_name,
InvocationType='Event',
Payload=json.dumps({"State": {"Bucket": bucket, "Key": "AUDIT_STATE.json"}})
)
logger.info(f'Lambda response: {repr(response)}')
def audit(format='plain', output_file=None, sns_topic=None, state=None, lambda_ctx=None):
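    """
    Audit parquet objects in OPENSEARCH_BUCKET against OpenSearch documents and report the S3 keys that have no
    matching record. Long runs inside Lambda persist their progress in the state dict and re-invoke themselves.
    """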
if format == 'mock-s3' and not output_file:
raise ValueError('Output file MUST be defined with mock-s3 output format')
if state is None:
state = {}
# Check if AWS credentials are set
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', None)
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', None)
AWS_SESSION_TOKEN = os.environ.get('AWS_SESSION_TOKEN', None)
if AWS_ACCESS_KEY_ID is None or AWS_SECRET_ACCESS_KEY is None:
logger.error('AWS credentials are not set. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.')
        sys.exit(1)
# Check if OpenSearch parameters are set
OPENSEARCH_ENDPOINT = os.environ.get('OPENSEARCH_ENDPOINT', None)
    OPENSEARCH_PORT = int(os.environ.get('OPENSEARCH_PORT', 443))
OPENSEARCH_INDEX = os.environ.get('OPENSEARCH_INDEX', None)
OPENSEARCH_PATH_PREFIX = append_slash(os.environ.get('OPENSEARCH_PATH_PREFIX', None))
OPENSEARCH_BUCKET = os.environ.get('OPENSEARCH_BUCKET', None)
OPENSEARCH_ID_PREFIX = append_slash(os.environ.get('OPENSEARCH_ID_PREFIX', f's3://{OPENSEARCH_BUCKET}/'))
    if None in (OPENSEARCH_ENDPOINT, OPENSEARCH_PORT, OPENSEARCH_ID_PREFIX, OPENSEARCH_INDEX,
                OPENSEARCH_PATH_PREFIX, OPENSEARCH_BUCKET):
        logger.error('OpenSearch parameters are not set. Please set OPENSEARCH_ENDPOINT, OPENSEARCH_PORT, '
                     'OPENSEARCH_ID_PREFIX, OPENSEARCH_INDEX, OPENSEARCH_PATH_PREFIX, OPENSEARCH_BUCKET.')
        sys.exit(1)
# AWS session
aws_session_param = {}
aws_session_param['aws_access_key_id'] = AWS_ACCESS_KEY_ID
aws_session_param['aws_secret_access_key'] = AWS_SECRET_ACCESS_KEY
if AWS_SESSION_TOKEN:
aws_session_param['aws_session_token'] = AWS_SESSION_TOKEN
aws_session = boto3.Session(**aws_session_param)
# AWS auth
aws_auth = AWS4Auth(
aws_session.get_credentials().access_key,
aws_session.get_credentials().secret_key,
aws_session.region_name,
'es',
session_token=aws_session.get_credentials().token
)
    # S3 and Lambda clients
s3 = aws_session.client('s3')
lambda_client = aws_session.client('lambda')
phase = state.get('state', 'start')
phase = PHASES[phase]
marker = state.get('marker')
keys: list = state.get('keys', [])
opensearch_client = Elasticsearch(
hosts=[{'host': OPENSEARCH_ENDPOINT, 'port': OPENSEARCH_PORT}],
http_auth=aws_auth,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection
)
# Go through all files in a bucket
count = 0
error_count = 0
error_s3_keys = []
missing_keys = []
logger.info('processing... will print out S3 keys that cannot find a match...')
if phase < 2:
if phase == 0:
logger.info(f'Starting listing of bucket {OPENSEARCH_BUCKET}')
state['listStartTime'] = datetime.now(timezone.utc)
else:
logger.info(f'Resuming listing of bucket {OPENSEARCH_BUCKET}')
logger.info(f'Listing objects older than {state["lastListTime"].strftime("%Y-%m-%dT%H:%M:%S%z")}')
while True:
list_kwargs = dict(
Bucket=OPENSEARCH_BUCKET,
Prefix=OPENSEARCH_PATH_PREFIX,
MaxKeys=1000
)
if marker:
list_kwargs['ContinuationToken'] = marker
page = s3.list_objects_v2(**list_kwargs)
keys_to_add = []
for key in page.get('Contents', []):
if key['Key'].endswith('parquet') and key['LastModified'] >= state['lastListTime']:
keys_to_add.append(key['Key'])
keys.extend(keys_to_add)
logger.info(f"Listed page of {len(page.get('Contents', [])):,} objects; selected {len(keys_to_add):,}; "
f"total={len(keys):,}")
if lambda_ctx:
remaining_time = timedelta(milliseconds=lambda_ctx.get_remaining_time_in_millis())
logger.info(f'Remaining time: {remaining_time}')
if not page['IsTruncated']:
break
else:
marker = page['NextContinuationToken']
if lambda_ctx is not None and lambda_ctx.get_remaining_time_in_millis() < (60 * 1000):
logger.warning('Lambda is about to time out, re-invoking to resume from this key')
                state = dict(
                    state='list',
                    marker=marker,
                    keys=keys,
                    lastListTime=state['lastListTime'],
                    listStartTime=state['listStartTime']
                )
reinvoke(state, OPENSEARCH_BUCKET, s3, lambda_client, lambda_ctx.function_name)
return
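        # Listing finished: future incremental audits only need to consider objects newer than when this listing began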
state['lastListTime'] = state['listStartTime']
del state['listStartTime']
n_keys = len(keys)
logger.info(f'Beginning audit on {n_keys:,} keys...')
need_to_resume = False
while len(keys) > 0:
key = keys.pop(0)
count += 1
        try:
            # Look the key up in OpenSearch; the document id is the object's S3 URI
            opensearch_id = OPENSEARCH_ID_PREFIX + key
            opensearch_response = opensearch_client.get(index=OPENSEARCH_INDEX, id=opensearch_id)
            if not isinstance(opensearch_response, dict) or not opensearch_response.get('found'):
                error_count += 1
                error_s3_keys.append(key)
                sys.stdout.write("\x1b[2K")  # ANSI: erase current line
                logger.info(key)
                missing_keys.append(key)
        except NotFoundError:
            error_count += 1
            error_s3_keys.append(key)
            sys.stdout.write("\x1b[2K")  # ANSI: erase current line
            logger.info(key)
            missing_keys.append(key)
        except Exception as e:
            logger.warning(f'Unexpected error checking key {key}: {e}')
            error_count += 1
if count % 50 == 0:
logger.info(f'Checked {count} files [{(count/n_keys)*100:7.3f}%]')
if lambda_ctx:
remaining_time = timedelta(milliseconds=lambda_ctx.get_remaining_time_in_millis())
logger.info(f'Remaining time: {remaining_time}')
if lambda_ctx is not None and lambda_ctx.get_remaining_time_in_millis() < (60 * 1000):
logger.warning('Lambda is about to time out, re-invoking to resume from this key')
state = dict(
state='audit',
marker=key,
keys=keys,
lastListTime=state['lastListTime']
)
need_to_resume = True
break
logger.info(f'Checked {count} files')
logger.info(f'Found {len(missing_keys):,} missing keys')
if len(missing_keys) > 0:
if format == 'plain':
if output_file:
with open(output_file, 'w') as f:
for key in missing_keys:
f.write(key + '\n')
else:
logger.info('Not writing to file as none was given')
else:
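            # In mock-s3 mode, output_file holds the SQS queue URL; each missing key is republished as a mock S3 event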
sqs_messages = [key_to_sqs_msg(k, OPENSEARCH_BUCKET) for k in missing_keys]
sqs = AwsSQS()
sqs_response = None
for m in sqs_messages:
sqs_response = sqs.send_message(output_file, m)
logger.info(f'SQS response: {repr(sqs_response)}')
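            # Optionally notify via SNS that missing keys were found and republished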
if sns_topic:
sns = AwsSNS()
sns_response = sns.publish(
sns_topic,
f'Parquet stats audit found {len(missing_keys):,} missing keys. Trying to republish to SQS.',
'Insitu audit'
)
logger.info(f'SNS response: {repr(sns_response)}')
if need_to_resume:
reinvoke(state, OPENSEARCH_BUCKET, s3, lambda_client, lambda_ctx.function_name)
return
# Finished, reset state to just last list time
logger.info('Audit complete! Persisting state to S3')
state = {'lastListTime': state['lastListTime'].strftime("%Y-%m-%dT%H:%M:%S%z")}
object_data = json.dumps(state).encode('utf-8')
s3.put_object(Bucket=OPENSEARCH_BUCKET, Key='AUDIT_STATE.json', Body=object_data)
if __name__ == '__main__':
# Parse arguments
parser = argparse.ArgumentParser(
description='Audit parquet files in S3 against records in OpenSearch',
epilog=textwrap.dedent('''\
Environment variables:
AWS_ACCESS_KEY_ID : AWS access key ID for S3 bucket & OpenSearch index
AWS_SECRET_ACCESS_KEY : AWS secret access key for S3 bucket & OpenSearch index
AWS_REGION : AWS region for S3 bucket & OpenSearch index
OPENSEARCH_ENDPOINT : Endpoint for OpenSearch domain
OPENSEARCH_PORT : Port to connect to OpenSearch (Default: 443)
OPENSEARCH_BUCKET : Name of the bucket storing ingested Parquet files.
OPENSEARCH_PATH_PREFIX : Key prefix for objects in OPENSEARCH_BUCKET to audit.
OPENSEARCH_ID_PREFIX : S3 URI prefix for the id field in OpenSearch documents. Defaults to 's3://<OPENSEARCH_BUCKET>/'
OPENSEARCH_INDEX : OpenSearch index to audit
'''),
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
'-o', '--output-file',
nargs='?',
type=str,
        help='File to which the S3 keys of objects not found in OpenSearch will be written',
dest='output'
)
parser.add_argument(
'-f', '--format',
choices=['plain', 'mock-s3'],
default='plain',
dest='format',
help='Output format. \'plain\' will output keys of missing parquet files to the output file in plain text. '
'\'mock-s3\' will output missing keys to SQS (-o is required as the SQS queue URL), formatted as an S3 '
'object created event'
)
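    # Parse an ISO datetime string and normalize it to UTC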
def utcfromisoformat(s):
return datetime.fromisoformat(s).astimezone(timezone.utc)
parser.add_argument(
'--from-time',
type=utcfromisoformat,
default=datetime(1970, 1, 1, tzinfo=timezone.utc),
dest='llt',
        help='Audit all objects newer than this time (ISO datetime string). Default: 1970-01-01'
)
args = parser.parse_args()
audit(args.format, args.output, state=dict(lastListTime=args.llt))