DataConnectors/AWS-SecurityHubFindings/AzFunAWSSecurityHubIngestion/__init__.py (334 lines of code) (raw):
import base64
import datetime
import hashlib
import hmac
import json
import logging
import os
import re
import time
from threading import Thread
import azure.functions as func
import boto3
import boto3.exceptions
import dateutil
import requests
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import (
AzureCliCredential,
ChainedTokenCredential,
DefaultAzureCredential,
ManagedIdentityCredential,
)
# ---------------------------------------------------------------------------
# Module configuration. Every setting is read once, at import time, from the
# process environment; a missing variable yields None.
# ---------------------------------------------------------------------------
client_id = os.environ.get("ClientID")
sentinel_customer_id = os.environ.get("WorkspaceID")
sentinel_shared_key = os.environ.get("WorkspaceKey")
# Should be the full ARN, including the AWS account number,
# e.g. arn:aws:iam::133761391337:role/AzureSentinelSyncRole
aws_role_arn = os.environ.get("AWSRoleArn")
aws_role_session_name = os.environ.get("AWSRoleSessionName")
aws_region_name = os.environ.get("AWSRegionName")
# Optional Security Hub filter expression; parsed into a dict inside main().
aws_securityhub_filters = os.environ.get("SecurityHubFilters")
sentinel_log_type = os.environ.get("LogAnalyticsCustomLogName")
# Look-back window; main() converts it with int() and treats it as minutes.
fresh_event_timestamp = os.environ.get("FreshEventTimeStamp")
logAnalyticsUri = os.environ.get("LAURI")

if logAnalyticsUri in (None, "") or str(logAnalyticsUri).isspace():
    # No explicit endpoint configured: derive the default one from the
    # workspace id. Fail with a readable message instead of the TypeError that
    # string concatenation with None would otherwise raise.
    if sentinel_customer_id in (None, "") or str(sentinel_customer_id).isspace():
        raise Exception(
            "AWSSecurityHubFindingsDataconnector: Missing Workspace ID. Please update configuration."
        )
    logAnalyticsUri = "https://" + sentinel_customer_id + ".ods.opinsights.azure.com"

# The dot between "azure" and the top-level domain is escaped so that it only
# matches a literal "." (an unescaped "." would accept any character there,
# e.g. "...opinsights.azureXevil.com").
pattern = r"https:\/\/([\w\-]+)\.ods\.opinsights\.azure\.([a-zA-Z\.]+)$"
match = re.match(pattern, str(logAnalyticsUri))
if not match:
    raise Exception(
        "AWSSecurityHubFindingsDataconnector: Invalid Log Analytics Uri. Please update configuration."
    )

if client_id in (None, "") or str(client_id).isspace():
    raise Exception(
        "AWSSecurityHubFindingsDataconnector: Missing Client ID. Please update configuration."
    )

# Finding attributes copied verbatim into the Sentinel record ("N/A" is used
# by main() when an attribute is absent).
payload_fields = [
    "SchemaVersion",
    "Id",
    "ProductArn",
    "GeneratorId",
    "AwsAccountId",
    "Types",
    "FirstObservedAt",
    "LastObservedAt",
    "UpdatedAt",
    "Title",
    "ProductFields",
    "CreatedAt",
    "Resources",
    "WorkflowState",
    "RecordState",
    "Compliance",
]

# Finding attributes stored as a JSON-encoded string. main() applies these
# after payload_fields, so the JSON form of "ProductFields" is what is sent.
payload_json_fields = ["Severity", "ProductFields"]
def main(mytimer: func.TimerRequest) -> None:
    """Timer-triggered entry point: forward fresh Security Hub findings to Sentinel.

    Steps:
      1. Obtain an Azure AD token for ``client_id`` through a chained credential.
      2. Use that token to build a ``SecurityHubClient`` (AWS web-identity role).
      3. Page through findings sorted by ``LastObservedAt`` (newest first) and
         queue every finding newer than the configured look-back window on the
         ``AzureSentinelConnector``; stop at the first finding that is too old.
      4. Log how many events were sent and how many failed.

    Args:
        mytimer: Timer trigger object; only ``past_due`` is read.

    Raises:
        ClientAuthenticationError: when no credential in the chain can supply
            a token.
        ValueError / SyntaxError: when the ``SecurityHubFilters`` setting is
            not a dict literal.
    """
    # Local import: only needed to parse the SecurityHubFilters setting.
    import ast

    if mytimer.past_due:
        logging.info("The timer is past due!")
    logging.info("Starting program")

    # auth to azure ad
    logging.info("Authenticating to Azure AD.")
    try:
        credential_chain = ChainedTokenCredential(
            ManagedIdentityCredential(),
            AzureCliCredential(),
            DefaultAzureCredential(exclude_shared_token_cache_credential=True),
        )
        token = credential_chain.get_token(client_id).token
    except ClientAuthenticationError as error:
        # Without a token the AWS role cannot be assumed, so carrying on would
        # only fail later with a less helpful error. Surface the real cause.
        logging.error("Authenticating to Azure AD failed: %s", error)
        raise

    sentinel = AzureSentinelConnector(
        logAnalyticsUri,
        sentinel_customer_id,
        sentinel_shared_key,
        sentinel_log_type,
        queue_size=10000,
        bulks_number=10,
    )
    securityHubSession = SecurityHubClient(
        aws_role_arn, aws_role_session_name, aws_region_name, token
    )

    securityhub_filters_dict = {}
    logging.info("SecurityHubFilters : {0}".format(aws_securityhub_filters))
    if aws_securityhub_filters:
        securityhub_filters = aws_securityhub_filters.replace("'", '"')
        # NOTE(security): this setting used to be passed to eval(), which
        # executes arbitrary code. literal_eval accepts the same dict/list/
        # str/number/bool literals but never runs code.
        securityhub_filters_dict = ast.literal_eval(securityhub_filters)
        if not isinstance(securityhub_filters_dict, dict):
            raise ValueError("SecurityHubFilters must be a dictionary literal.")

    results = securityHubSession.getFindings(securityhub_filters_dict)
    fresh_events_after_this_time = securityHubSession.freshEventTimestampGenerator(
        int(fresh_event_timestamp)
    )

    # One `with` around the whole paging loop: the connector flushes on exit,
    # so events are batched instead of being posted one request per finding.
    with sentinel:
        while True:
            # Loop through all findings (100 per page) returned by the
            # Security Hub API call. Findings arrive newest first, so the
            # first stale one means everything after it is stale as well.
            reached_stale_finding = False
            for finding in results["Findings"]:
                finding_timestamp = securityHubSession.findingTimestampGenerator(
                    finding["LastObservedAt"]
                )
                if finding_timestamp <= fresh_events_after_this_time:
                    reached_stale_finding = True
                    break

                logging.info("SecurityHub Finding:{0}".format(json.dumps(finding)))
                payload = {
                    field: finding.get(field, "N/A") for field in payload_fields
                }
                # Applied second on purpose: the JSON-encoded value replaces
                # the raw one for any field present in both lists.
                payload.update(
                    {
                        json_field: json.dumps(
                            finding.get(json_field, "N/A"), sort_keys=True
                        )
                        for json_field in payload_json_fields
                    }
                )
                sentinel.send(payload)

            # The page is always processed BEFORE checking for a continuation
            # token, so the final page (which carries no NextToken) is not
            # dropped.
            if reached_stale_finding or "NextToken" not in results:
                break
            results = securityHubSession.getFindingsWithToken(
                results["NextToken"], securityhub_filters_dict
            )

    failed_sent_events_number = sentinel.failed_sent_events_number
    successfull_sent_events_number = sentinel.successfull_sent_events_number

    if failed_sent_events_number:
        logging.error("{} events have not been sent".format(failed_sent_events_number))
    if successfull_sent_events_number:
        logging.info(
            "Program finished. {} events have been sent. {} events have not been sent".format(
                successfull_sent_events_number, failed_sent_events_number
            )
        )
    if successfull_sent_events_number == 0 and failed_sent_events_number == 0:
        logging.info("No Fresh SecurityHub Events")
class SecurityHubClient:
    """AWS Security Hub client authenticated through a web-identity token.

    The constructor exchanges the supplied token for temporary AWS credentials
    (STS ``assume_role_with_web_identity``) and uses them to create the
    ``securityhub`` boto3 client stored on ``self.securityhub``.
    """

    def __init__(self, aws_role_arn, aws_role_session_name, aws_region_name, token):
        """Assume the AWS role and create the Security Hub client.

        Args:
            aws_role_arn: ARN of the role to assume.
            aws_role_session_name: Session name passed to STS.
            aws_region_name: Region used for the Security Hub client.
            token: Web identity token presented to STS.

        Raises:
            Exception: whatever the STS call raises; it is logged and re-raised
                because nothing below can work without credentials.
        """
        # define input
        self.role_arn = aws_role_arn
        self.role_session_name = aws_role_session_name
        self.aws_region_name = aws_region_name
        self.web_identity_token = token

        # create an STS client object that represents a live connection to the STS service
        sts_client = boto3.client("sts")

        # call assume_role method using input + client
        try:
            assumed_role_object = sts_client.assume_role_with_web_identity(
                RoleArn=self.role_arn,
                RoleSessionName=self.role_session_name,
                WebIdentityToken=self.web_identity_token,
            )
        except Exception as error:
            # Previously only logged, after which the code subscripted None
            # and died with an unrelated TypeError.
            logging.error("Assuming role with web identity failed: %s", error)
            raise
        logging.info("Successfully assumed role with web identity.")

        # from the response, get credentials
        credentials = assumed_role_object["Credentials"]
        logging.info("AccessKeyId : {0}".format(credentials["AccessKeyId"]))
        logging.info(
            "AssumedRoleArn : {0}".format(assumed_role_object["AssumedRoleUser"]["Arn"])
        )

        # use temp creds to make connection
        self.securityhub = boto3.client(
            "securityhub",
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            region_name=self.aws_region_name,
        )

    def freshEventTimestampGenerator(self, freshEventsDuration):
        """Return the epoch time (seconds) ``freshEventsDuration`` minutes ago.

        Computed directly from ``time.time()`` so the result is a true epoch
        value regardless of the host's local time zone.
        """
        window = datetime.timedelta(minutes=freshEventsDuration)
        return time.time() - window.total_seconds()

    # Gets the epoch time of a UTC timestamp in a Security Hub finding
    def findingTimestampGenerator(self, finding_time):
        """Return the epoch time (seconds) of a finding timestamp string.

        The UTC offset carried by the string is honoured; a string without an
        offset is interpreted as UTC.
        """
        # Import the submodule explicitly: a bare `import dateutil` does not
        # guarantee that `dateutil.parser` is loaded.
        import dateutil.parser

        parsed = dateutil.parser.parse(finding_time)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=datetime.timezone.utc)
        # datetime.timestamp() on an aware value is time-zone independent,
        # unlike time.mktime(), which reads the tuple as local time.
        return parsed.timestamp()

    def _get_findings(self, filters, **extra):
        """Call ``get_findings`` with the shared page size and sort order."""
        return self.securityhub.get_findings(
            Filters={} if filters is None else filters,
            MaxResults=100,
            SortCriteria=[{"Field": "LastObservedAt", "SortOrder": "desc"}],
            **extra,
        )

    # Gets 100 most recent findings from securityhub
    def getFindings(self, filters=None):
        """Return the first page (up to 100) of findings, newest first.

        Args:
            filters: Security Hub filter dict; ``None`` means no filtering.
        """
        return self._get_findings(filters)

    # Gets 100 findings from securityhub using the NextToken from a previous request
    def getFindingsWithToken(self, token, filters=None):
        """Return the page of findings that follows ``token``.

        Args:
            token: ``NextToken`` value from the previous response.
            filters: Security Hub filter dict; ``None`` means no filtering.
        """
        return self._get_findings(filters, NextToken=token)
class AzureSentinelConnector:
    """Buffer events and post them to the Log Analytics ``/api/logs`` endpoint.

    Events passed to ``send`` accumulate in a queue. A full queue becomes a
    "bulk"; once ``bulks_number`` bulks are waiting (or ``flush`` is forced)
    every bulk is posted, one thread per request, after being split so that
    each request body stays under ``queue_size_bytes``.

    The instance can be used as a context manager; leaving the ``with`` block
    flushes everything still buffered.

    Attributes:
        successfull_sent_events_number: events accepted (2xx response).
        failed_sent_events_number: events in requests that failed.
        failedToSend: outcome of the most recently completed request.
    """

    # Seconds to wait on the HTTP connection/response before giving up, so a
    # stalled endpoint cannot hang a worker thread forever.
    _REQUEST_TIMEOUT_SECONDS = 120

    def __init__(
        self,
        log_analytics_uri,
        customer_id,
        shared_key,
        log_type,
        queue_size=200,
        bulks_number=10,
        queue_size_bytes=25 * (2**20),
    ):
        """Store the connection settings and create empty buffers.

        Args:
            log_analytics_uri: Base URI of the workspace endpoint.
            customer_id: Workspace id used in the Authorization header.
            shared_key: Base64-encoded workspace key used to sign requests.
            log_type: Value of the ``Log-Type`` header.
            queue_size: Number of events that makes the queue a bulk.
            bulks_number: Number of bulks that triggers an automatic send.
            queue_size_bytes: Upper bound (exclusive) for one request body.
        """
        # Local import: the module header only imports Thread.
        from threading import Lock

        self.log_analytics_uri = log_analytics_uri
        self.customer_id = customer_id
        self.shared_key = shared_key
        self.log_type = log_type
        self.queue_size = queue_size
        self.bulks_number = bulks_number
        self.queue_size_bytes = queue_size_bytes
        self._queue = []
        self._bulks_list = []
        self.successfull_sent_events_number = 0
        self.failed_sent_events_number = 0
        self.failedToSend = False
        # _post_data runs on several threads at once; "+=" on the counters is
        # a read-modify-write and needs to be serialized.
        self._counters_lock = Lock()

    def send(self, event):
        """Queue one event; hand the queue over once it reaches ``queue_size``."""
        self._queue.append(event)
        if len(self._queue) >= self.queue_size:
            self.flush(force=False)

    def flush(self, force=True):
        """Move the current queue to the bulk list and possibly send.

        Args:
            force: when true, send every waiting bulk now; otherwise send only
                if at least ``bulks_number`` bulks are waiting.
        """
        self._bulks_list.append(self._queue)
        if force or len(self._bulks_list) >= self.bulks_number:
            self._flush_bulks()
        self._queue = []

    def _flush_bulks(self):
        """Post every waiting bulk concurrently and wait for completion."""
        jobs = [
            Thread(
                target=self._post_data,
                args=(self.customer_id, self.shared_key, chunk, self.log_type),
            )
            for bulk in self._bulks_list
            if bulk  # empty queues are appended by flush(); nothing to send
            for chunk in self._split_big_request(bulk)
        ]
        for job in jobs:
            job.start()
        for job in jobs:
            job.join()
        self._bulks_list = []

    def __enter__(self):
        # Return the connector so `with AzureSentinelConnector(...) as c:` works.
        return self

    def __exit__(self, type, value, traceback):
        # Always flush, even when the block raised; returning None lets any
        # exception propagate.
        self.flush()

    def _build_signature(
        self,
        customer_id,
        shared_key,
        date,
        content_length,
        method,
        content_type,
        resource,
    ):
        """Return the ``SharedKey`` Authorization header value.

        The newline-joined request description is signed with HMAC-SHA256
        using the base64-decoded ``shared_key``; the digest is base64-encoded.
        """
        string_to_hash = "\n".join(
            [method, str(content_length), content_type, "x-ms-date:" + date, resource]
        )
        digest = hmac.new(
            base64.b64decode(shared_key),
            string_to_hash.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        encoded_hash = base64.b64encode(digest).decode()
        return "SharedKey {}:{}".format(customer_id, encoded_hash)

    def _post_data(self, customer_id, shared_key, body, log_type):
        """POST one list of events and record whether it was accepted.

        Network errors are caught and counted as failures so that a worker
        thread never dies without its events being accounted for.
        """
        events_number = len(body)
        # json.dumps escapes non-ASCII by default, so len(body) below equals
        # the byte length of the request body.
        body = json.dumps(body, sort_keys=True)
        method = "POST"
        content_type = "application/json"
        resource = "/api/logs"
        rfc1123date = datetime.datetime.now(datetime.timezone.utc).strftime(
            "%a, %d %b %Y %H:%M:%S GMT"
        )
        content_length = len(body)
        signature = self._build_signature(
            customer_id,
            shared_key,
            rfc1123date,
            content_length,
            method,
            content_type,
            resource,
        )
        uri = self.log_analytics_uri + resource + "?api-version=2016-04-01"
        headers = {
            "content-type": content_type,
            "Authorization": signature,
            "Log-Type": log_type,
            "x-ms-date": rfc1123date,
        }

        try:
            response = requests.post(
                uri, data=body, headers=headers, timeout=self._REQUEST_TIMEOUT_SECONDS
            )
        except requests.RequestException as error:
            logging.error(
                "Error during sending events to Azure Sentinel: {}".format(error)
            )
            sent_ok = False
        else:
            sent_ok = 200 <= response.status_code <= 299
            if not sent_ok:
                logging.error(
                    "Error during sending events to Azure Sentinel. Response code: {}".format(
                        response.status_code
                    )
                )

        with self._counters_lock:
            if sent_ok:
                self.successfull_sent_events_number += events_number
            else:
                self.failed_sent_events_number += events_number
            self.failedToSend = not sent_ok

    def _check_size(self, queue):
        """Return True when the JSON form of ``queue`` is under the byte limit."""
        data_bytes_len = len(json.dumps(queue).encode())
        return data_bytes_len < self.queue_size_bytes

    def _split_big_request(self, queue):
        """Recursively halve ``queue`` until every part is under the byte limit.

        A part holding a single event is returned as-is even when it is too
        large: it cannot be split any further, and halving it would recurse
        forever (one half empty, the other the same list).
        """
        if len(queue) <= 1 or self._check_size(queue):
            return [queue]
        middle = len(queue) // 2
        return self._split_big_request(queue[:middle]) + self._split_big_request(
            queue[middle:]
        )