def audit()

in parquet_cli/audit_tool/audit_tool.py
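
For orientation, a plausible set of module-level imports for this function, inferred from the calls it makes rather than copied from the source file; append_slash, reinvoke, key_to_sqs_msg, PHASES, AwsSQS, AwsSNS, and logger are assumed to be defined elsewhere in audit_tool.py:

import json
import os
import sys
from datetime import datetime, timedelta, timezone

import boto3
from elasticsearch import Elasticsearch, NotFoundError, RequestsHttpConnection
from requests_aws4auth import AWS4Auth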


def audit(format='plain', output_file=None, sns_topic=None, state=None, lambda_ctx=None):
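    """
    Audit parquet objects in OPENSEARCH_BUCKET against OpenSearch documents and
    report S3 keys that have no matching document. Designed to run on AWS Lambda:
    execution is resumable through the state dict, and the function re-invokes
    itself when the invocation is about to time out.

    Parameter notes below are inferred from the function body, not original docs:

    :param format: 'plain' writes missing keys to output_file; other formats
        republish them to SQS, with output_file naming the destination queue
    :param output_file: file path ('plain') or SQS destination (other formats)
    :param sns_topic: optional SNS topic to notify when missing keys are found
    :param state: resume state ('state' phase, 'marker', 'keys', 'lastListTime')
    :param lambda_ctx: Lambda context object, used to watch remaining run time
    """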
    if format == 'mock-s3' and not output_file:
        raise ValueError('Output file MUST be defined with mock-s3 output format')

    if state is None:
        state = {}

    # Check if AWS credentials are set
    AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', None)
    AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', None)
    AWS_SESSION_TOKEN = os.environ.get('AWS_SESSION_TOKEN', None)
    if AWS_ACCESS_KEY_ID is None or AWS_SECRET_ACCESS_KEY is None:
        logger.error('AWS credentials are not set. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.')
        sys.exit(1)

    # Check if OpenSearch parameters are set
    OPENSEARCH_ENDPOINT = os.environ.get('OPENSEARCH_ENDPOINT', None)
    OPENSEARCH_PORT = int(os.environ.get('OPENSEARCH_PORT', 443))  # env values are strings; coerce to int
    OPENSEARCH_INDEX = os.environ.get('OPENSEARCH_INDEX', None)
    OPENSEARCH_PATH_PREFIX = append_slash(os.environ.get('OPENSEARCH_PATH_PREFIX', None))
    OPENSEARCH_BUCKET = os.environ.get('OPENSEARCH_BUCKET', None)
    OPENSEARCH_ID_PREFIX = append_slash(os.environ.get('OPENSEARCH_ID_PREFIX', f's3://{OPENSEARCH_BUCKET}/'))
    if OPENSEARCH_ENDPOINT is None or OPENSEARCH_PORT is None or OPENSEARCH_ID_PREFIX is None or OPENSEARCH_INDEX is None or OPENSEARCH_PATH_PREFIX is None or OPENSEARCH_BUCKET is None:
        logger.error('OpenSearch parameters are not set. Please set OPENSEARCH_ENDPOINT, OPENSEARCH_PORT, OPENSEARCH_ID_PREFIX, OPENSEARCH_INDEX, OPENSEARCH_PATH_PREFIX, OPENSEARCH_BUCKET.')
        sys.exit(1)

    # AWS session
    aws_session_param = {}
    aws_session_param['aws_access_key_id'] = AWS_ACCESS_KEY_ID
    aws_session_param['aws_secret_access_key'] = AWS_SECRET_ACCESS_KEY

    if AWS_SESSION_TOKEN:
        aws_session_param['aws_session_token'] = AWS_SESSION_TOKEN

    aws_session = boto3.Session(**aws_session_param)

    # AWS auth (SigV4 signing for the OpenSearch domain)
    credentials = aws_session.get_credentials()
    aws_auth = AWS4Auth(
        credentials.access_key,
        credentials.secret_key,
        aws_session.region_name,
        'es',
        session_token=credentials.token
    )

    # S3 and Lambda clients
    s3 = aws_session.client('s3')
    lambda_client = aws_session.client('lambda')

    # Resumable position: PHASES is assumed to map 'start'/'list'/'audit' to ordered ints (0/1/2)
    phase = PHASES[state.get('state', 'start')]
    marker = state.get('marker')
    keys: list = state.get('keys', [])

    opensearch_client = Elasticsearch(
        hosts=[{'host': OPENSEARCH_ENDPOINT, 'port': OPENSEARCH_PORT}],
        http_auth=aws_auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection
    )

    # Go through all files in a bucket
    count = 0
    error_count = 0
    error_s3_keys = []
    missing_keys = []
    logger.info('Processing... S3 keys with no matching OpenSearch document will be printed as they are found...')
    if phase < 2:
        if phase == 0:
            logger.info(f'Starting listing of bucket {OPENSEARCH_BUCKET}')
            state['listStartTime'] = datetime.now(timezone.utc)
            # Guard for a first run with no persisted state: fall back to the epoch so every object is selected
            state.setdefault('lastListTime', datetime.fromtimestamp(0, tz=timezone.utc))
        else:
            logger.info(f'Resuming listing of bucket {OPENSEARCH_BUCKET}')

        logger.info(f'Listing objects modified since {state["lastListTime"].strftime("%Y-%m-%dT%H:%M:%S%z")}')

        while True:
            list_kwargs = dict(
                Bucket=OPENSEARCH_BUCKET,
                Prefix=OPENSEARCH_PATH_PREFIX,
                MaxKeys=1000
            )

            if marker:
                list_kwargs['ContinuationToken'] = marker

            page = s3.list_objects_v2(**list_kwargs)
            keys_to_add = []

            for obj in page.get('Contents', []):
                if obj['Key'].endswith('parquet') and obj['LastModified'] >= state['lastListTime']:
                    keys_to_add.append(obj['Key'])

            keys.extend(keys_to_add)

            logger.info(f"Listed page of {len(page.get('Contents', [])):,} objects; selected {len(keys_to_add):,}; "
                        f"total={len(keys):,}")

            if lambda_ctx:
                remaining_time = timedelta(milliseconds=lambda_ctx.get_remaining_time_in_millis())
                logger.info(f'Remaining time: {remaining_time}')

            if not page['IsTruncated']:
                break

            marker = page['NextContinuationToken']

            if lambda_ctx is not None and lambda_ctx.get_remaining_time_in_millis() < (60 * 1000):
                logger.warning('Lambda is about to time out, re-invoking to resume from this key')

                state = dict(
                    state='list',
                    marker=marker,
                    keys=keys,
                    # Preserve both timestamps: listStartTime becomes the new lastListTime once listing completes
                    listStartTime=state['listStartTime'],
                    lastListTime=state['lastListTime']
                )

                reinvoke(state, OPENSEARCH_BUCKET, s3, lambda_client, lambda_ctx.function_name)
                return

        state['lastListTime'] = state['listStartTime']
        del state['listStartTime']

    # if phase == 3 and marker is not None and marker in keys:
    #     logger.info(f'Resuming audit from key {marker}')
    #     index = keys.index(marker)
    # else:
    #     logger.info('Starting audit from the beginning')
    #     index = 0
    #
    # keys = keys[index:]

    n_keys = len(keys)

    logger.info(f'Beginning audit on {n_keys:,} keys...')

    need_to_resume = False

    while len(keys) > 0:
        key = keys.pop(0)
        count += 1
        try:
            # Look up the key in OpenSearch by document ID
            # (OPENSEARCH_ID_PREFIX already ends with a slash via append_slash, so plain concatenation suffices)
            opensearch_id = OPENSEARCH_ID_PREFIX + key
            opensearch_response = opensearch_client.get(index=OPENSEARCH_INDEX, id=opensearch_id)
            if opensearch_response is None or not isinstance(opensearch_response, dict) or not opensearch_response['found']:
                error_count += 1
                error_s3_keys.append(key)
                sys.stdout.write("\x1b[2K")  # ANSI erase-line escape
                logger.info(key)
                missing_keys.append(key)
        except NotFoundError:
            error_count += 1
            error_s3_keys.append(key)
            sys.stdout.write("\x1b[2K")  # ANSI erase-line escape
            logger.info(key)
            missing_keys.append(key)
        except Exception as e:
            # Don't let one bad key kill the audit, but don't swallow the error silently either
            logger.warning(f'Error checking key {key}: {e}')
            error_count += 1

        if count % 50 == 0:
            logger.info(f'Checked {count} files [{(count/n_keys)*100:7.3f}%]')
            if lambda_ctx:
                remaining_time = timedelta(milliseconds=lambda_ctx.get_remaining_time_in_millis())
                logger.info(f'Remaining time: {remaining_time}')

        if lambda_ctx is not None and lambda_ctx.get_remaining_time_in_millis() < (60 * 1000):
            logger.warning('Lambda is about to time out, re-invoking to resume from this key')

            state = dict(
                state='audit',
                marker=key,
                keys=keys,
                lastListTime=state['lastListTime']
            )

            need_to_resume = True

            break

    logger.info(f'Checked {count} files')
    logger.info(f'Found {len(missing_keys):,} missing keys')

    if len(missing_keys) > 0:
        if format == 'plain':
            if output_file:
                with open(output_file, 'w') as f:
                    for key in missing_keys:
                        f.write(key + '\n')
            else:
                logger.info('Not writing to file as none was given')
        else:
            # For non-plain formats, output_file is reused as the SQS destination passed to send_message
            sqs_messages = [key_to_sqs_msg(k, OPENSEARCH_BUCKET) for k in missing_keys]

            sqs = AwsSQS()

            sqs_response = None

            for m in sqs_messages:
                sqs_response = sqs.send_message(output_file, m)

            logger.info(f'SQS response: {repr(sqs_response)}')

            if sns_topic:
                sns = AwsSNS()

                sns_response = sns.publish(
                    sns_topic,
                    f'Parquet stats audit found {len(missing_keys):,} missing keys. Trying to republish to SQS.',
                    'Insitu audit'
                )

                logger.info(f'SNS response: {repr(sns_response)}')

    if need_to_resume:
        reinvoke(state, OPENSEARCH_BUCKET, s3, lambda_client, lambda_ctx.function_name)
        return

    # Finished, reset state to just last list time

    logger.info('Audit complete! Persisting state to S3')

    state = {'lastListTime': state['lastListTime'].strftime("%Y-%m-%dT%H:%M:%S%z")}

    object_data = json.dumps(state).encode('utf-8')

    s3.put_object(Bucket=OPENSEARCH_BUCKET, Key='AUDIT_STATE.json', Body=object_data)
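
For context, a minimal sketch of how audit() might be driven from a Lambda handler that restores the AUDIT_STATE.json persisted above; the handler name, event shape, and state parsing are assumptions, not part of this module:

def lambda_handler(event, context):
    # State either arrives from a prior re-invocation or is restored from S3
    state = event.get('state') if isinstance(event, dict) else None

    if state is None:
        s3 = boto3.client('s3')
        try:
            obj = s3.get_object(Bucket=os.environ['OPENSEARCH_BUCKET'], Key='AUDIT_STATE.json')
            state = json.loads(obj['Body'].read())
        except s3.exceptions.NoSuchKey:
            state = {}

    # lastListTime is persisted as a string (see above); parse it back into an aware datetime
    if isinstance(state.get('lastListTime'), str):
        state['lastListTime'] = datetime.strptime(state['lastListTime'], '%Y-%m-%dT%H:%M:%S%z')

    audit(state=state, lambda_ctx=context)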