setup/setup-ci/security-scanner/amlsecscan.py
#!/usr/bin/env python3
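# Security scanner for Azure ML Compute Instances: runs malware scans (ClamAV) and OS/Python
# vulnerability scans (Trivy) and reports the results to an Azure Log Analytics workspace.
#
# Illustrative invocations (the Log Analytics resource ID below is a placeholder):
#   sudo python3 amlsecscan.py install -la /subscriptions/<subscription>/resourceGroups/<resource-group>/providers/Microsoft.OperationalInsights/workspaces/<workspace>
#   sudo python3 amlsecscan.py scan all -o stdout
#   sudo python3 amlsecscan.py uninstall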
import argparse
import base64
import datetime
import hashlib
import hmac
import json
import logging
import logging.handlers
import os
import re
import shutil
import subprocess
import sys
import time
import requests
from urllib.parse import urlparse
# User accounts running this script:
# _azbatch - user account running scripts during Azure ML Compute Instance creation - can sudo
# azureuser - user account running scripts after Azure ML Compute Instance creation - can sudo on Compute Instances created with rootAccess = true (default)
# root - does not need introduction...
_logger = logging.getLogger("amlsecscan")
_computer = os.environ["CI_NAME"]
_azure_ml_resource_id = (
"/" + urlparse(os.environ["MLFLOW_TRACKING_URI"]).path.split("/", 3)[3]
) # Get the ARM Resource ID of the Azure ML Workspace we are running on
# Configuration priority: 1) command-line parameters, 2) local config file, 3) global config file
_config_folder_path = "/home/azureuser/.amlsecscan"
_global_config_path = _config_folder_path + "/config.json"
_local_config_path = os.path.abspath(os.path.splitext(__file__)[0] + ".json")
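# Illustrative config file content (the resource ID is a placeholder):
#   {"logAnalyticsResourceId": "/subscriptions/<subscription>/resourceGroups/<resource-group>/providers/Microsoft.OperationalInsights/workspaces/<workspace>"}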
# Replacement for azure.identity.DefaultAzureCredential().get_token since azure.identity is not available in the conda base environment and does not handle Azure ML's MSI
def _get_access_token(resource):
# Ensure the MSI environment variables are set (by default, they are set in shells when running in AML Studio Terminal but not when running in CRON)
if "MSI_ENDPOINT" not in os.environ or "MSI_SECRET" not in os.environ:
env_var = _get_auth_environment_variables()
os.environ["MSI_ENDPOINT"] = env_var["MSI_ENDPOINT"]
os.environ["MSI_SECRET"] = env_var["MSI_SECRET"]
url = f"{os.environ['MSI_ENDPOINT']}?resource={resource}&api-version=2017-09-01"
client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID", None)
if (
client_id is not None
and re.match(
"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}",
client_id,
re.IGNORECASE,
)
is not None
):
url = f"{url}&clientid={client_id}"
resp = requests.get(url, headers={"Secret": os.environ["MSI_SECRET"]})
resp.raise_for_status()
return resp.json()["access_token"]
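# Illustrative use: _get_access_token("https://management.azure.com") returns a bearer token string
# used in the "Authorization: Bearer <token>" headers of the ARM requests below.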
def _run(command, check=True):
# To be compatible with Python 3.6 (default python for root user), 'text' and 'capture_output' cannot be used
try:
return subprocess.run(
command,
shell=True,
check=check,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
_logger.exception(
f"Error: {e}\n stdout:\n{e.stdout}\n stderr:\n{e.stderr}"
)
raise
class StdOutTelemetry:
def send(self, log_type, data):
print(json.dumps({"table": log_type, "rows": data}))
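        # Illustrative output line (placeholder values):
        #   {"table": "AmlSecurityComputeHealth", "rows": [{"WorkspaceId": "...", "Computer": "...", "Type": "Heartbeat", "Status": "", "Details": ""}]}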
class LogAnalyticsTelemetry:
def __init__(self, log_analytics_resource_id):
# Get the ARM Resource ID of the Log Analytics Workspace
if log_analytics_resource_id is None:
config_path = (
_local_config_path
if os.path.exists(_local_config_path)
else _global_config_path
if os.path.exists(_global_config_path)
else None
)
if config_path is not None:
_logger.debug(f"Loading configuration from {config_path}")
with open(config_path, "rt") as file:
config = json.load(file)
log_analytics_resource_id = config["logAnalyticsResourceId"]
self.log_analytics_resource_id = _sanitize_log_analytics_resource_id(
log_analytics_resource_id
)
# Get an AAD access token for ARM
access_token = _get_access_token("https://management.azure.com")
headers = {
"Authorization": "Bearer " + access_token
} # [SuppressMessage("Microsoft.Security", "CS001:SecretInline", Justification="No secret")]
# Get the Log Analytics Customer ID from ARM
response = requests.get(
"https://management.azure.com"
+ self.log_analytics_resource_id
+ "?api-version=2021-06-01",
headers=headers,
)
response.raise_for_status()
self.log_analytics_customer_id = response.json()["properties"]["customerId"]
# Get the Log Analytics Shared Key from ARM
response = requests.post(
"https://management.azure.com"
+ self.log_analytics_resource_id
+ "/sharedKeys?api-version=2020-08-01",
headers=headers,
)
response.raise_for_status()
self.log_analytics_shared_key = response.json()["primarySharedKey"]
_logger.debug(f"Azure ML Workspace ARM Resource ID: {_azure_ml_resource_id}")
_logger.debug(
f"Log Analytics Workspace ARM Resource ID: {self.log_analytics_resource_id}"
)
_logger.debug(f"Log Analytics Customer ID: {self.log_analytics_customer_id}")
# From: https://docs.microsoft.com/en-us/azure/azure-monitor/logs/data-collector-api#python-sample
def _build_signature(self, date, content_length, method, content_type, resource):
x_headers = "x-ms-date:" + date
string_to_hash = (
method
+ "\n"
+ str(content_length)
+ "\n"
+ content_type
+ "\n"
+ x_headers
+ "\n"
+ resource
)
bytes_to_hash = bytes(string_to_hash, encoding="utf-8")
decoded_key = base64.b64decode(self.log_analytics_shared_key)
encoded_hash = base64.b64encode(
hmac.new(decoded_key, bytes_to_hash, digestmod=hashlib.sha256).digest()
).decode()
authorization = "SharedKey {}:{}".format(
self.log_analytics_customer_id, encoded_hash
)
return authorization
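    # The string signed above follows the Data Collector API canonical form (fields separated by newlines):
    #   POST, <content length>, application/json, x-ms-date:<RFC 1123 date>, /api/logs
    # It is HMAC-SHA256 signed with the base64-decoded workspace shared key, then base64-encoded.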
def send(self, log_type, data):
body = json.dumps(data)
method = "POST"
content_type = "application/json"
resource = "/api/logs"
rfc1123date = datetime.datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")
content_length = len(body)
signature = self._build_signature(
rfc1123date, content_length, method, content_type, resource
)
headers = {
"content-type": content_type,
"Authorization": signature,
"Log-Type": log_type,
"x-ms-date": rfc1123date,
}
response = requests.post(
"https://"
+ self.log_analytics_customer_id
+ ".ods.opinsights.azure.com"
+ resource
+ "?api-version=2016-04-01",
data=body,
headers=headers,
)
response.raise_for_status()
_logger.info(
f"Sent {len(data)} telemetry row(s) to table {log_type} in Log Analytics workspace {self.log_analytics_resource_id}"
)
_logger.debug(f"Telemetry rows: {data}")
def _send_health(telemetry, type_, status=None, details=None):
telemetry.send(
"AmlSecurityComputeHealth",
[
{
"WorkspaceId": _azure_ml_resource_id,
"Computer": _computer,
"Type": type_, # Enum: Heartbeat, ScanMalware, ScanOsVulnerabilities, ScanPythonVulnerabilities
"Status": status
if status is not None
else "", # Enum: Started, Succeeded, Failed, ''
"Details": json.dumps(details) if details is not None else "",
}
],
)
def _send_assessment(telemetry, type_, findings, details=None):
telemetry.send(
"AmlSecurityComputeAssessments",
[
{
"WorkspaceId": _azure_ml_resource_id,
"Computer": _computer,
"Type": type_, # Enum: Malware, OsVulnerabilities, PythonVulnerabilities
"Status": "Healthy" if findings == 0 else "Unhealthy",
"Findings": findings,
"Details": json.dumps(details) if details is not None else "",
}
],
)
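# Illustrative assessment row (placeholder values): a scan that finds 3 HIGH/CRITICAL OS CVEs yields
# Type="OsVulnerabilities", Status="Unhealthy", Findings=3, with the finding list serialized in Details.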
def _get_log_analytics_from_diagnostic_settings():
# Get an AAD access token for ARM
access_token = _get_access_token("https://management.azure.com")
headers = {
"Authorization": "Bearer " + access_token
} # [SuppressMessage("Microsoft.Security", "CS001:SecretInline", Justification="No secret")]
# List diagnostic settings on the Azure ML workspace
response = requests.get(
"https://management.azure.com"
+ _azure_ml_resource_id
+ "/providers/microsoft.insights/diagnosticSettings?api-version=2021-05-01-preview",
headers=headers,
)
response.raise_for_status()
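    # Expected (simplified, illustrative) response shape:
    #   {"value": [{"properties": {"workspaceId": "/subscriptions/.../workspaces/...", ...}, ...}]}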
# Select the first Log Analytics workspace
for settings in response.json()["value"]:
if "workspaceId" in settings["properties"]:
return settings["properties"]["workspaceId"]
return None
def _install(log_analytics_resource_id):
if os.geteuid() != 0:
raise Exception(
"Installation must be performed by the root user. Please run again using sudo."
)
_logger.debug(f"Creating folder {_config_folder_path}")
os.makedirs(_config_folder_path, exist_ok=True)
shutil.chown(_config_folder_path, "azureuser", "azureuser")
config = {"logAnalyticsResourceId": None}
# Load config file if present
if os.path.exists(_local_config_path):
_logger.debug(f"Loading configuration from {_local_config_path}")
with open(_local_config_path, "rt") as file:
config.update(json.load(file))
_logger.debug(
f"logAnalyticsResourceId after loading config file: {config['logAnalyticsResourceId']}"
)
# Set Log Analytics workspace ARM Resource ID if passed via command-line parameter
if log_analytics_resource_id is not None:
config["logAnalyticsResourceId"] = log_analytics_resource_id
_logger.debug(
f"logAnalyticsResourceId after setting command-line parameter: {config['logAnalyticsResourceId']}"
)
# Retrieve Log Analytics workspace ARM Resource ID from Azure ML diagnostic settings if
# provided neither via local config file nor command-line parameter
if config.get("logAnalyticsResourceId", None) is None:
config["logAnalyticsResourceId"] = _get_log_analytics_from_diagnostic_settings()
_logger.debug(
f"logAnalyticsResourceId after querying Azure ML diagnostic settings: {config['logAnalyticsResourceId']}"
)
# Sanitize the Log Analytics workspace ARM Resource ID
config["logAnalyticsResourceId"] = _sanitize_log_analytics_resource_id(
config["logAnalyticsResourceId"]
)
_logger.debug(f"Configuration: {config}")
_logger.info(f"Writing configuration file {_global_config_path}")
with open(_global_config_path, "wt") as file:
json.dump(config, file, indent=2)
shutil.chown(_global_config_path, "azureuser", "azureuser")
_logger.info("Installing Trivy")
_run(
"apt-get install -y --no-install-recommends --quiet wget apt-transport-https gnupg lsb-release"
)
_run(
"wget -qO - https://aquasecurity.github.io/trivy-repo/deb/public.key | apt-key add -"
)
_run(
"echo deb https://aquasecurity.github.io/trivy-repo/deb $(lsb_release -sc) main | tee -a /etc/apt/sources.list.d/trivy.list"
)
_run("apt-get update")
_run("apt-get install -y --no-install-recommends --quiet trivy")
script_path = _config_folder_path + "/run.sh"
_logger.info(f"Writing script file {script_path}")
with open(script_path, "wt") as file:
file.write(
f"""#!/bin/bash
set -e
exec 1> >(logger -s -t AMLSECSCAN) 2>&1
# Limit CPU usage to 20% and reduce priority (note: the configuration is not persisted during reboot)
if [ ! -d /sys/fs/cgroup/cpu/amlsecscan ]
then
mkdir -p /sys/fs/cgroup/cpu/amlsecscan
echo 100000 | tee /sys/fs/cgroup/cpu/amlsecscan/cpu.cfs_period_us > /dev/null
echo 20000 | tee /sys/fs/cgroup/cpu/amlsecscan/cpu.cfs_quota_us > /dev/null
echo 5 | tee /sys/fs/cgroup/cpu/amlsecscan/cpu.shares > /dev/null
fi
echo $$ | tee /sys/fs/cgroup/cpu/amlsecscan/tasks > /dev/null
nice -n 19 python3 {os.path.abspath(__file__)} $1 $2 $3 $4 $5
"""
)
os.chmod(script_path, 0o0755)
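    # The cron entries below emit a heartbeat every 10 minutes, run a full scan daily at 05:37,
    # and run a full scan 10 minutes after each reboot.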
_logger.info(f"Writing crontab file /etc/cron.d/amlsecscan")
with open("/etc/cron.d/amlsecscan", "wt") as file:
file.write(
f"""*/10 * * * * root {script_path} heartbeat
37 5 * * * root {script_path} scan all
@reboot root sleep 600 && {script_path} scan all
"""
)
os.chmod("/etc/cron.d/amlsecscan", 0o0644)
def _uninstall():
if os.geteuid() != 0:
raise Exception(
"Uninstallation must be performed by the root user. Please run again using sudo."
)
_logger.info(f"Deleting crontab file /etc/cron.d/amlsecscan")
_run("rm -f /etc/cron.d/amlsecscan")
_logger.info(f"Deleting folder {_config_folder_path}")
shutil.rmtree(_config_folder_path, ignore_errors=True)
def _sanitize_log_analytics_resource_id(log_analytics_resource_id):
if log_analytics_resource_id is None:
raise ValueError(
"Log Analytics Workspace ARM Resource ID missing. Please provide it either via config file, command-line parameter, or Azure ML diagnostic settings."
)
log_analytics_resource_id = log_analytics_resource_id.strip()
if len(log_analytics_resource_id.split("/")) != 9:
raise ValueError(
"Log Analytics Workspace ARM Resource ID format should be /subscriptions/{subscription}/resourceGroups/{resource_group}/providers/Microsoft.OperationalInsights/workspaces/{workspace} instead of '"
+ log_analytics_resource_id
+ "'"
)
return log_analytics_resource_id
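# /etc/environment.sso is expected to hold KEY=VALUE lines, e.g. (illustrative placeholders):
#   MSI_ENDPOINT=<local managed-identity endpoint URL>
#   MSI_SECRET=<opaque secret>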
def _get_auth_environment_variables():
out = _run("cat /etc/environment.sso")
return {
pair[0]: pair[1]
for pair in [line.split("=", 2) for line in out.stdout.splitlines()]
}
def _parse_clamav_stdout(stdout):
files = []
details = {}
findings = 0
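    # Illustrative clamscan output consumed by this parser (placeholder values):
    #   /path/to/file: <SignatureName> FOUND
    #   Known viruses: <count>
    #   Engine version: <version>
    #   Scanned directories: <count>
    #   Scanned files: <count>
    #   Infected files: <count>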
for line in stdout.splitlines():
match = re.match(r"^(.+?):\s*(.+?)\s+FOUND", line)
if match is not None:
files.append({"path": match.group(1), "malwareType": match.group(2)})
continue
match = re.match(r"Infected files:\s*(\d+)", line)
if match is not None:
findings = int(match.group(1))
continue
match = re.match(r"Known viruses:\s*(\d+)", line)
if match is not None:
details["knownViruses"] = int(match.group(1))
continue
match = re.match(r"Engine version:\s*(.+)", line)
if match is not None:
details["engineVersion"] = match.group(1)
continue
match = re.match(r"Scanned files:\s*(\d+)", line)
if match is not None:
details["scannedFiles"] = int(match.group(1))
continue
match = re.match(r"Scanned directories:\s*(\d+)", line)
if match is not None:
details["scannedDirectories"] = int(match.group(1))
continue
if findings != len(files):
raise Exception(
f"Failed to parse ClamAV stdout (findings: {findings}, files: {len(files)})"
)
if len(files) > 0:
details["files"] = files
return (findings, details)
def _parse_trivy_results(trivy_scan_path):
findings_os = []
findings_python = []
with open(trivy_scan_path, "rt") as file:
data = json.load(file)
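    # Simplified, illustrative shape of the Trivy JSON report consumed below:
    #   {"Results": [{"Class": "os-pkgs" or "lang-pkgs", "Type": "...", "Target": "...",
    #                 "Vulnerabilities": [{"VulnerabilityID": "...", "PkgName": "...",
    #                                      "InstalledVersion": "...", "Severity": "...", "Title": "..."}]}]}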
for result in data["Results"]:
if result["Class"] == "os-pkgs":
for vulnerability in result.get("Vulnerabilities", []):
findings_os.append(
{
"title": vulnerability.get(
"Title",
vulnerability["PkgName"]
+ " "
+ vulnerability["VulnerabilityID"],
),
"packageName": vulnerability["PkgName"],
"packageVersion": vulnerability["InstalledVersion"],
"CVE": vulnerability["VulnerabilityID"],
"severity": vulnerability["Severity"],
}
)
elif result["Class"] == "lang-pkgs" and result["Type"] == "pip":
for vulnerability in result.get("Vulnerabilities", []):
findings_python.append(
{
"title": vulnerability.get(
"Title",
vulnerability["PkgName"]
+ " "
+ vulnerability["VulnerabilityID"],
),
"packageName": vulnerability["PkgName"],
"packageVersion": vulnerability["InstalledVersion"],
"file": result["Target"],
"CVE": vulnerability["VulnerabilityID"],
"severity": vulnerability["Severity"],
}
)
else:
_logger.warning(
f"Skipping unhandled vulnerability of class {result['Class']} and type {result['Type']} for file {result['Target']}. "
)
return (findings_os, findings_python)
# Limit the finding list to top 50 by severity so that the Log Analytics limit of 32K string length is not hit (which truncates JSON strings and makes them invalid)
def _filter_trivy_results(findings):
    severity_rank = {"CRITICAL": 0, "HIGH": 1}
    return sorted(findings, key=lambda x: severity_rank.get(x["severity"], 2))[:50]
def _scan_vulnerabilities(telemetry):
start_time = time.time()
_send_health(telemetry, "ScanVulnerabilities", "Started")
try:
shutil.rmtree(f"{_config_folder_path}/anaconda", ignore_errors=True)
for env_name in (
entry.name for entry in os.scandir("/anaconda/envs") if entry.is_dir()
):
_logger.info(
f"Saving pip freeze of conda environment {env_name} to {_config_folder_path}/anaconda/{env_name}/requirements.txt"
)
os.makedirs(f"{_config_folder_path}/anaconda/{env_name}", exist_ok=True)
_run(
f"/anaconda/envs/{env_name}/bin/python3 -m pip freeze > {_config_folder_path}/anaconda/{env_name}/requirements.txt"
)
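        # Writing each environment's pip freeze output as requirements.txt lets Trivy's filesystem scan
        # report Python package vulnerabilities per conda environment (parsed as lang-pkgs/pip below).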
_logger.info("Running Trivy scan")
_run(
f"/usr/local/bin/trivy filesystem --format json --output {_config_folder_path}/trivy.json --security-checks vuln --severity HIGH,CRITICAL --ignore-unfixed /"
)
findings_os, findings_python = _parse_trivy_results(
f"{_config_folder_path}/trivy.json"
)
_send_assessment(
telemetry,
"OsVulnerabilities",
len(findings_os),
{"findings": _filter_trivy_results(findings_os)}
if len(findings_os) > 0
else None,
)
_send_assessment(
telemetry,
"PythonVulnerabilities",
len(findings_python),
{"findings": _filter_trivy_results(findings_python)}
if len(findings_python) > 0
else None,
)
_send_health(
telemetry,
"ScanVulnerabilities",
"Succeeded",
{"elapsedTimeInS": time.time() - start_time},
)
return True
except subprocess.CalledProcessError as e:
_send_health(
telemetry,
"ScanVulnerabilities",
"Failed",
{
"error": str(e),
"stdout": e.stdout,
"stderr": e.stderr,
"elapsedTimeInS": time.time() - start_time,
},
)
return False
except Exception as e:
_logger.exception(f"Error: {e}")
_send_health(
telemetry,
"ScanVulnerabilities",
"Failed",
{"error": str(e), "elapsedTimeInS": time.time() - start_time},
)
return False
def _scan_malware(telemetry):
start_time = time.time()
_send_health(telemetry, "ScanMalware", "Started")
try:
# Run ClamAV (with AzSecPack malware definitions if present)
database_option = (
"-d /var/lib/azsec-clamav"
if os.path.exists("/var/lib/azsec-clamav")
else ""
)
        command = f"clamscan {database_option} -r -i --exclude-dir=^/sys/ /bin /boot /home /lib /lib64 /opt /root /sbin /anaconda"
_logger.info(f"Running: {command}")
out = _run(command, check=False)
# returncode:
# == 0 -> clamscan completed scan without finding malware
# == 1 -> clamscan completed scan with malware found
# >= 2 -> clamscan failed to scan
if out.returncode >= 2:
raise Exception(f"Scan failed with exit code {out.returncode}")
findings, details = _parse_clamav_stdout(out.stdout)
if findings == 0 and out.returncode != 0:
raise Exception(
f"Failed to parse ClamAV stdout (findings: {findings}, exit code: {out.returncode})"
)
_send_assessment(telemetry, "Malware", findings, details)
_send_health(
telemetry,
"ScanMalware",
"Succeeded",
{"elapsedTimeInS": time.time() - start_time},
)
return True
except subprocess.CalledProcessError as e:
_send_health(
telemetry,
"ScanMalware",
"Failed",
{
"error": str(e),
"stdout": e.stdout,
"stderr": e.stderr,
"elapsedTimeInS": time.time() - start_time,
},
)
return False
except Exception as e:
_logger.exception(e)
_send_health(
telemetry,
"ScanMalware",
"Failed",
{"error": str(e), "elapsedTimeInS": time.time() - start_time},
)
return False
def _add_common_arguments(parser):
parser.add_argument(
"-la",
"--log-analytics-resource-id",
help="ARM Resource ID of the Log Analytics workspace to log telemetry to",
dest="log_analytics_resource_id",
)
parser.add_argument(
"-ll",
"--log-level",
help="level of log messages to display (default: INFO)",
dest="log_level",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
)
parser.add_argument(
"-o",
"--output",
help="output (default: log-analytics)",
dest="output",
choices=["log-analytics", "stdout"],
)
if __name__ == "__main__":
# Logging to stdout (forwarded to syslog in run.sh)
_logger.setLevel(logging.INFO)
_logger.addHandler(logging.StreamHandler(stream=sys.stdout))
try:
# Command-line parser
parser = argparse.ArgumentParser(
description="Azure ML Compute Security Scanner"
)
subparsers = parser.add_subparsers(dest="command")
# Command: "install"
parser_install = subparsers.add_parser(
"install",
help="Install dependencies and start scheduled scans. Must be run as root (use sudo).",
)
_add_common_arguments(parser_install)
# Command: "uninstall"
parser_uninstall = subparsers.add_parser(
"uninstall", help="Remove scheduled scans. Must be run as root (use sudo)."
)
_add_common_arguments(parser_uninstall)
# Command: "heartbeat"
parser_heartbeat = subparsers.add_parser(
"heartbeat", help="Emit a telemetry heartbeat"
)
_add_common_arguments(parser_heartbeat)
# Command: "scan"
parser_scan = subparsers.add_parser("scan", help="Run security scans")
subparsers_scan = parser_scan.add_subparsers(dest="scan_type")
# Command: "scan all"
parser_scan_all = subparsers_scan.add_parser(
"all", help="Run all security scans"
)
_add_common_arguments(parser_scan_all)
# Command: "scan malware"
parser_scan_malware = subparsers_scan.add_parser(
"malware", help="Scan for malware"
)
_add_common_arguments(parser_scan_malware)
# Command: "scan vulnerabilities"
parser_scan_vulnerabilities = subparsers_scan.add_parser(
"vulnerabilities", help="Scan for OS and Python vulnerabilities"
)
_add_common_arguments(parser_scan_vulnerabilities)
args = parser.parse_args()
if args.command is None:
parser.print_help()
exit(1)
if "log_level" in args and args.log_level is not None:
_logger.setLevel(getattr(logging, args.log_level))
if args.command == "install":
_install(args.log_analytics_resource_id)
elif args.command == "uninstall":
_uninstall()
elif args.command == "heartbeat":
telemetry = (
StdOutTelemetry()
if args.output == "stdout"
else LogAnalyticsTelemetry(args.log_analytics_resource_id)
)
_send_health(telemetry, "Heartbeat")
elif args.command == "scan":
if args.scan_type is None:
parser.print_help()
exit(1)
telemetry = (
StdOutTelemetry()
if args.output == "stdout"
else LogAnalyticsTelemetry(args.log_analytics_resource_id)
)
if args.scan_type == "all":
success0 = _scan_vulnerabilities(telemetry)
success1 = _scan_malware(telemetry)
exit(0 if success0 and success1 else 2)
# TODO: Python vulns
elif args.scan_type == "vulnerabilities":
success = _scan_vulnerabilities(telemetry)
exit(0 if success else 2)
elif args.scan_type == "malware":
success = _scan_malware(telemetry)
exit(0 if success else 2)
else:
raise ValueError(f"Insupported scan type '{args.scan_type}'")
else:
raise ValueError(f"Insupported command '{args.command}'")
except Exception as e:
_logger.critical(f"Unhandled exception: {e}")
raise