DataConnectors/AWS-SecurityHubFindings/AzFunAWSSecurityHubIngestion/__init__.py
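"""Azure Function (timer trigger) that pulls findings from AWS Security Hub and
forwards them to a Microsoft Sentinel (Log Analytics) custom table.

Flow: authenticate to Azure AD, exchange the token for temporary AWS
credentials via STS AssumeRoleWithWebIdentity, page through Security Hub
findings (newest first), and post fresh events to the Log Analytics HTTP
Data Collector API.
"""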

import base64
import datetime
import hashlib
import hmac
import json
import logging
import os
import re
import time
from threading import Thread

import azure.functions as func
import boto3
import boto3.exceptions
import dateutil.parser
import dateutil.tz
import requests
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import (
    AzureCliCredential,
    ChainedTokenCredential,
    DefaultAzureCredential,
    ManagedIdentityCredential,
)

client_id = os.environ.get("ClientID")
sentinel_customer_id = os.environ.get("WorkspaceID")
sentinel_shared_key = os.environ.get("WorkspaceKey")
# Should be the full ARN, including the AWS account number,
# e.g. arn:aws:iam::133761391337:role/AzureSentinelSyncRole
aws_role_arn = os.environ.get("AWSRoleArn")
aws_role_session_name = os.environ.get("AWSRoleSessionName")
aws_region_name = os.environ.get("AWSRegionName")
aws_securityhub_filters = os.environ.get("SecurityHubFilters")
sentinel_log_type = os.environ.get("LogAnalyticsCustomLogName")
fresh_event_timestamp = os.environ.get("FreshEventTimeStamp")
logAnalyticsUri = os.environ.get("LAURI")

if logAnalyticsUri in (None, "") or str(logAnalyticsUri).isspace():
    logAnalyticsUri = "https://" + sentinel_customer_id + ".ods.opinsights.azure.com"

pattern = r"https:\/\/([\w\-]+)\.ods\.opinsights\.azure\.([a-zA-Z\.]+)$"
match = re.match(pattern, str(logAnalyticsUri))
if not match:
    raise Exception(
        "AWSSecurityHubFindingsDataconnector: Invalid Log Analytics Uri. Please update configuration."
    )

if client_id in (None, "") or str(client_id).isspace():
    raise Exception(
        "AWSSecurityHubFindingsDataconnector: Missing Client ID. Please update configuration."
    )

# Top-level finding fields copied into the Sentinel payload as-is.
payload_fields = [
    "SchemaVersion",
    "Id",
    "ProductArn",
    "GeneratorId",
    "AwsAccountId",
    "Types",
    "FirstObservedAt",
    "LastObservedAt",
    "UpdatedAt",
    "Title",
    "ProductFields",
    "CreatedAt",
    "Resources",
    "WorkflowState",
    "RecordState",
    "Compliance",
]
# Nested fields serialized to JSON strings before sending.
payload_json_fields = ["Severity", "ProductFields"]


def main(mytimer: func.TimerRequest) -> None:
    if mytimer.past_due:
        logging.info("The timer is past due!")
    logging.info("Starting program")

    # Authenticate to Azure AD and obtain a token for the web identity exchange.
    logging.info("Authenticating to Azure AD.")
    token = None
    try:
        managed_identity = ManagedIdentityCredential()
        azure_cli = AzureCliCredential()
        default_azure_credential = DefaultAzureCredential(
            exclude_shared_token_cache_credential=True
        )
        credential_chain = ChainedTokenCredential(
            managed_identity, azure_cli, default_azure_credential
        )
        token_meta = credential_chain.get_token(client_id)
        token = token_meta.token
    except ClientAuthenticationError as error:
        logging.error("Authenticating to Azure AD failed: %s" % error)

    sentinel = AzureSentinelConnector(
        logAnalyticsUri,
        sentinel_customer_id,
        sentinel_shared_key,
        sentinel_log_type,
        queue_size=10000,
        bulks_number=10,
    )
    securityHubSession = SecurityHubClient(
        aws_role_arn, aws_role_session_name, aws_region_name, token
    )

    securityhub_filters_dict = {}
    logging.info("SecurityHubFilters : {0}".format(aws_securityhub_filters))
    if aws_securityhub_filters:
        securityhub_filters = aws_securityhub_filters.replace("'", '"')
        # Parse the filter string as JSON (json.loads rather than eval, so the
        # environment variable cannot inject executable code).
        securityhub_filters_dict = json.loads(securityhub_filters)
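
    # For reference, a SecurityHubFilters value might look like the following
    # (illustrative only; keys follow the AWS AwsSecurityFindingFilters schema):
    #   {'SeverityLabel': [{'Value': 'HIGH', 'Comparison': 'EQUALS'}]}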
    results = securityHubSession.getFindings(securityhub_filters_dict)
    fresh_events_after_this_time = securityHubSession.freshEventTimestampGenerator(
        int(fresh_event_timestamp)
    )
    fresh_events = True
    first_call = True
    failed_sent_events_number = 0
    successful_sent_events_number = 0

    # Loop through all findings (100 per page) returned by the Security Hub API.
    # Break out of the loop once we have looked back across the configured
    # freshness window (based on each finding's LastObservedAt timestamp).
    while (first_call or "NextToken" in results) and fresh_events:
        first_call = False
        for finding in results["Findings"]:
            finding_timestamp = securityHubSession.findingTimestampGenerator(
                finding["LastObservedAt"]
            )
            if finding_timestamp > fresh_events_after_this_time:
                logging.info("SecurityHub Finding:{0}".format(json.dumps(finding)))
                payload = {}
                for field in payload_fields:
                    payload.update({field: finding.get(field, "N/A")})
                for json_field in payload_json_fields:
                    payload.update(
                        {
                            json_field: json.dumps(
                                finding.get(json_field, "N/A"), sort_keys=True
                            )
                        }
                    )
                # The context manager flushes the queue on exit, so each
                # payload is posted promptly.
                with sentinel:
                    sentinel.send(payload)
                failed_sent_events_number = sentinel.failed_sent_events_number
                successful_sent_events_number = sentinel.successful_sent_events_number
            else:
                # Findings are sorted newest first, so the first stale finding
                # means the rest of the results are stale too.
                fresh_events = False
                break
        if fresh_events and "NextToken" in results:
            results = securityHubSession.getFindingsWithToken(
                results["NextToken"], securityhub_filters_dict
            )

    if failed_sent_events_number:
        logging.error("{} events have not been sent".format(failed_sent_events_number))
    if successful_sent_events_number:
        logging.info(
            "Program finished. {} events have been sent. {} events have not been sent".format(
                successful_sent_events_number, failed_sent_events_number
            )
        )
    if successful_sent_events_number == 0 and failed_sent_events_number == 0:
        logging.info("No Fresh SecurityHub Events")


class SecurityHubClient:
    def __init__(self, aws_role_arn, aws_role_session_name, aws_region_name, token):
        self.role_arn = aws_role_arn
        self.role_session_name = aws_role_session_name
        self.aws_region_name = aws_region_name
        self.web_identity_token = token

        # Create an STS client and exchange the Azure AD token for temporary
        # AWS credentials via AssumeRoleWithWebIdentity.
        sts_client = boto3.client("sts")
        assumed_role_object = None
        try:
            assumed_role_object = sts_client.assume_role_with_web_identity(
                RoleArn=self.role_arn,
                RoleSessionName=self.role_session_name,
                WebIdentityToken=self.web_identity_token,
            )
            logging.info("Successfully assumed role with web identity.")
        except boto3.exceptions.Boto3Error as error:
            logging.error("Assuming role with web identity failed: %s" % error)

        # From the response, extract the temporary credentials.
        credentials = assumed_role_object["Credentials"]
        logging.info("AccessKeyId : {0}".format(credentials["AccessKeyId"]))
        logging.info(
            "AssumedRoleArn : {0}".format(assumed_role_object["AssumedRoleUser"]["Arn"])
        )

        # Use the temporary credentials to open the Security Hub connection.
        self.securityhub = boto3.client(
            "securityhub",
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            region_name=self.aws_region_name,
        )

    # Returns the epoch time that marks the start of the freshness window,
    # i.e. "now" minus freshEventsDuration minutes.
    def freshEventTimestampGenerator(self, freshEventsDuration):
        tm = datetime.datetime.utcfromtimestamp(time.time())
        return time.mktime(
            (tm - datetime.timedelta(minutes=freshEventsDuration)).timetuple()
        )

    # Gets the epoch time of a UTC timestamp in a Security Hub finding.
    def findingTimestampGenerator(self, finding_time):
        d = dateutil.parser.parse(finding_time)
        d = d.astimezone(dateutil.tz.tzutc())
        return time.mktime(d.timetuple())

    # Gets the 100 most recent findings from Security Hub.
    def getFindings(self, filters=None):
        return self.securityhub.get_findings(
            Filters=filters or {},
            MaxResults=100,
            SortCriteria=[{"Field": "LastObservedAt", "SortOrder": "desc"}],
        )

    # Gets 100 findings from Security Hub using the NextToken from a previous request.
    def getFindingsWithToken(self, token, filters=None):
        return self.securityhub.get_findings(
            Filters=filters or {},
            NextToken=token,
            MaxResults=100,
            SortCriteria=[{"Field": "LastObservedAt", "SortOrder": "desc"}],
        )
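
# A minimal pagination sketch for SecurityHubClient (assumes valid Azure AD and
# AWS role configuration; the filter value is a placeholder):
#
#   filters = {"RecordState": [{"Value": "ACTIVE", "Comparison": "EQUALS"}]}
#   client = SecurityHubClient(aws_role_arn, aws_role_session_name, "us-east-1", token)
#   page = client.getFindings(filters)
#   while "NextToken" in page:
#       for finding in page["Findings"]:
#           ...  # process each finding
#       page = client.getFindingsWithToken(page["NextToken"], filters)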


class AzureSentinelConnector:
    def __init__(
        self,
        log_analytics_uri,
        customer_id,
        shared_key,
        log_type,
        queue_size=200,
        bulks_number=10,
        queue_size_bytes=25 * (2**20),
    ):
        self.log_analytics_uri = log_analytics_uri
        self.customer_id = customer_id
        self.shared_key = shared_key
        self.log_type = log_type
        self.queue_size = queue_size
        self.bulks_number = bulks_number
        self.queue_size_bytes = queue_size_bytes
        self._queue = []
        self._bulks_list = []
        self.successful_sent_events_number = 0
        self.failed_sent_events_number = 0
        self.failedToSend = False

    def send(self, event):
        self._queue.append(event)
        if len(self._queue) >= self.queue_size:
            self.flush(force=False)

    def flush(self, force=True):
        self._bulks_list.append(self._queue)
        if force:
            self._flush_bulks()
        else:
            if len(self._bulks_list) >= self.bulks_number:
                self._flush_bulks()
        self._queue = []

    # Posts each pending bulk on its own thread, splitting bulks that exceed
    # the maximum request size.
    def _flush_bulks(self):
        jobs = []
        for queue in self._bulks_list:
            if queue:
                queue_list = self._split_big_request(queue)
                for q in queue_list:
                    jobs.append(
                        Thread(
                            target=self._post_data,
                            args=(
                                self.customer_id,
                                self.shared_key,
                                q,
                                self.log_type,
                            ),
                        )
                    )
        for job in jobs:
            job.start()
        for job in jobs:
            job.join()
        self._bulks_list = []

    def __enter__(self):
        pass

    def __exit__(self, type, value, traceback):
        self.flush()

    # Builds the HMAC-SHA256 "SharedKey" authorization header used by the
    # Log Analytics HTTP Data Collector API.
    def _build_signature(
        self,
        customer_id,
        shared_key,
        date,
        content_length,
        method,
        content_type,
        resource,
    ):
        x_headers = "x-ms-date:" + date
        string_to_hash = (
            method
            + "\n"
            + str(content_length)
            + "\n"
            + content_type
            + "\n"
            + x_headers
            + "\n"
            + resource
        )
        bytes_to_hash = bytes(string_to_hash, encoding="utf-8")
        decoded_key = base64.b64decode(shared_key)
        encoded_hash = base64.b64encode(
            hmac.new(decoded_key, bytes_to_hash, digestmod=hashlib.sha256).digest()
        ).decode()
        authorization = "SharedKey {}:{}".format(customer_id, encoded_hash)
        return authorization

    def _post_data(self, customer_id, shared_key, body, log_type):
        events_number = len(body)
        body = json.dumps(body, sort_keys=True)
        method = "POST"
        content_type = "application/json"
        resource = "/api/logs"
        rfc1123date = datetime.datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")
        content_length = len(body)
        signature = self._build_signature(
            customer_id,
            shared_key,
            rfc1123date,
            content_length,
            method,
            content_type,
            resource,
        )
        uri = self.log_analytics_uri + resource + "?api-version=2016-04-01"
        headers = {
            "content-type": content_type,
            "Authorization": signature,
            "Log-Type": log_type,
            "x-ms-date": rfc1123date,
        }
        response = requests.post(uri, data=body, headers=headers)
        if response.status_code >= 200 and response.status_code <= 299:
            self.successful_sent_events_number += events_number
            self.failedToSend = False
        else:
            logging.error(
                "Error during sending events to Azure Sentinel. Response code: {}".format(
                    response.status_code
                )
            )
            self.failed_sent_events_number += events_number
            self.failedToSend = True

    def _check_size(self, queue):
        data_bytes_len = len(json.dumps(queue).encode())
        return data_bytes_len < self.queue_size_bytes

    # Recursively halves a batch until every chunk fits under queue_size_bytes.
    def _split_big_request(self, queue):
        if self._check_size(queue):
            return [queue]
        else:
            middle = int(len(queue) / 2)
            queues_list = [queue[:middle], queue[middle:]]
            return self._split_big_request(queues_list[0]) + self._split_big_request(
                queues_list[1]
            )
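

# Local smoke-test sketch for AzureSentinelConnector (hypothetical workspace
# values; not executed by the Function App runtime):
#
#   connector = AzureSentinelConnector(
#       "https://<workspace-id>.ods.opinsights.azure.com",
#       "<workspace-id>",
#       "<base64-shared-key>",
#       "AwsSecurityHubFindings",
#   )
#   with connector:
#       connector.send({"Title": "test event", "Severity": json.dumps({"Label": "LOW"})})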