# emails/utils.py

"""Utility helpers for Relay email processing: SES/S3 access, tracker-list
handling, reply-key crypto, header parsing, and metrics/logging wrappers."""

from __future__ import annotations

import base64
import contextlib
import json
import logging
import pathlib
import re
import zlib
from collections.abc import Callable
from email.errors import HeaderParseError, InvalidHeaderDefect
from email.headerregistry import Address, AddressHeader
from email.message import EmailMessage
from email.utils import formataddr, parseaddr
from functools import cache
from typing import Any, Literal, TypeVar, cast
from urllib.parse import quote_plus, urlparse

from django.conf import settings
from django.contrib.auth.models import Group, User
from django.template.defaultfilters import linebreaksbr, urlize
from django.template.loader import render_to_string
from django.utils.text import Truncator

import jwcrypto.jwe
import jwcrypto.jwk
import markus
import requests
from allauth.socialaccount.models import SocialAccount
from botocore.exceptions import ClientError
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
from mypy_boto3_ses.type_defs import ContentTypeDef, SendRawEmailResponseTypeDef

from privaterelay.plans import get_bundle_country_language_mapping
from privaterelay.utils import get_countries_info_from_lang_and_mapping

from .apps import s3_client, ses_client

logger = logging.getLogger("events")
info_logger = logging.getLogger("eventsinfo")
study_logger = logging.getLogger("studymetrics")
metrics = markus.get_metrics()

# Upstream source for the Disconnect email-tracker block list.
shavar_prod_lists_url = (
    "https://raw.githubusercontent.com/mozilla-services/shavar-prod-lists/"
    "master/disconnect-blacklist.json"
)
EMAILS_FOLDER_PATH = pathlib.Path(__file__).parent
TRACKER_FOLDER_PATH = EMAILS_FOLDER_PATH / "tracker_lists"


def ses_message_props(data: str) -> ContentTypeDef:
    """Wrap a string as a UTF-8 SES message content dict."""
    return {"Charset": "UTF-8", "Data": data}


def get_domains_from_settings() -> (
    dict[Literal["RELAY_FIREFOX_DOMAIN", "MOZMAIL_DOMAIN"], str]
):
    """Return the Relay email domains, with fixed values under Django tests."""
    # HACK: detect if code is running in django tests
    if "testserver" in settings.ALLOWED_HOSTS:
        return {"RELAY_FIREFOX_DOMAIN": "default.com", "MOZMAIL_DOMAIN": "test.com"}
    return {
        "RELAY_FIREFOX_DOMAIN": settings.RELAY_FIREFOX_DOMAIN,
        "MOZMAIL_DOMAIN": settings.MOZMAIL_DOMAIN,
    }


def get_trackers(level):
    """Return the tracker domain list for the given level (1=general, 2=strict).

    Loads from the local JSON cache; on a cache miss, downloads from the
    shavar-prod-lists repo and stores the result for next time.
    """
    category = "Email"
    tracker_list_name = "level-one-trackers"
    if level == 2:
        category = "EmailAggressive"
        tracker_list_name = "level-two-trackers"

    trackers = []
    file_name = f"{tracker_list_name}.json"
    try:
        with open(TRACKER_FOLDER_PATH / file_name) as f:
            trackers = json.load(f)
    except FileNotFoundError:
        trackers = download_trackers(shavar_prod_lists_url, category)
        store_trackers(trackers, TRACKER_FOLDER_PATH, file_name)
    return trackers


def download_trackers(repo_url, category="Email"):
    """Fetch and flatten the tracker domains for *category* from *repo_url*."""
    # email tracker lists from shavar-prod-list as per agreed use under license:
    resp = requests.get(repo_url, timeout=10)
    json_resp = resp.json()
    formatted_trackers = json_resp["categories"][category]
    trackers = []
    # The list is nested entity -> resources -> domains; collect the leaves.
    for entity in formatted_trackers:
        for _, resources in entity.items():
            for _, domains in resources.items():
                trackers.extend(domains)
    return trackers


def store_trackers(trackers, path, file_name):
    """Write the tracker list as indented JSON to path/file_name."""
    with open(path / file_name, "w+") as f:
        json.dump(trackers, f, indent=4)


@cache
def general_trackers():
    """Cached level-1 (general) tracker list."""
    return get_trackers(level=1)


@cache
def strict_trackers():
    """Cached level-2 (strict) tracker list."""
    return get_trackers(level=2)


_TimedFunction = TypeVar("_TimedFunction", bound=Callable[..., Any])


def time_if_enabled(name: str) -> Callable[[_TimedFunction], _TimedFunction]:
    """Decorator factory: time the wrapped call under *name* when statsd is on,
    otherwise run it unchanged."""

    def timing_decorator(func: _TimedFunction) -> _TimedFunction:
        def func_wrapper(*args, **kwargs):
            ctx_manager = (
                metrics.timer(name)
                if settings.STATSD_ENABLED
                else contextlib.nullcontext()
            )
            with ctx_manager:
                return func(*args, **kwargs)

        return cast(_TimedFunction, func_wrapper)

    return timing_decorator


def incr_if_enabled(name, value=1, tags=None):
    """Increment the *name* counter when statsd metrics are enabled."""
    if settings.STATSD_ENABLED:
        metrics.incr(name, value, tags)


def histogram_if_enabled(name, value, tags=None):
    """Record a histogram value when statsd metrics are enabled."""
    if settings.STATSD_ENABLED:
        metrics.histogram(name, value=value, tags=tags)


def gauge_if_enabled(name, value, tags=None):
    """Set a gauge value when statsd metrics are enabled."""
    if settings.STATSD_ENABLED:
        metrics.gauge(name, value, tags)


def get_email_domain_from_settings() -> str:
    """Return the email domain derived from SITE_ORIGIN's network location."""
    email_network_locality = str(urlparse(settings.SITE_ORIGIN).netloc)
    # on dev server we need to add "mail" prefix
    # because we can't publish MX records on Heroku
    if settings.RELAY_CHANNEL == "dev":
        email_network_locality = f"mail.{email_network_locality}"
    return email_network_locality


def parse_email_header(header_value: str) -> list[tuple[str, str]]:
    """
    Extract the (display name, email address) pairs from a header value.

    This is useful when working with header values provided by a AWS SES
    delivery notification.

    email.utils.parseaddr() works with well-formed emails, but fails in cases
    with badly formed emails where an email address could still be extracted.
    """
    address_list = AddressHeader.value_parser(header_value)
    pairs: list[tuple[str, str]] = []
    for address in address_list.addresses:
        for mailbox in address.all_mailboxes:
            addr_spec = mailbox.addr_spec
            # Keep only addresses with exactly one "@"
            if addr_spec and addr_spec.count("@") == 1:
                pairs.append((mailbox.display_name or "", addr_spec))
    return pairs


def _get_hero_img_src(lang_code):
    """Return the URL of the welcome-email hero image localized for
    *lang_code*, falling back to the English image."""
    img_locale = "en"
    # Locales for which a localized hero image exists.
    avail_l10n_image_codes = [
        "cs",
        "de",
        "en",
        "es",
        "fi",
        "fr",
        "hu",
        "id",
        "it",
        "ja",
        "nl",
        "pt",
        "ru",
        "sv",
        "zh",
    ]
    # Use only the major language subtag (e.g. "fr" from "fr-CA").
    major_lang = lang_code.split("-")[0]
    if major_lang in avail_l10n_image_codes:
        img_locale = major_lang

    if not settings.SITE_ORIGIN:
        raise ValueError("settings.SITE_ORIGIN must have a value")
    return (
        settings.SITE_ORIGIN
        + f"/static/images/email-images/first-time-user/hero-image-{img_locale}.png"
    )


def get_welcome_email(user: User, format: str) -> str:
    """Render the first-time-user welcome email for *user* in the given
    template *format* (template suffix, e.g. "html" or "txt")."""
    sa = SocialAccount.objects.get(user=user)
    bundle_plans = get_countries_info_from_lang_and_mapping(
        sa.extra_data.get("locale", "en"), get_bundle_country_language_mapping()
    )
    lang_code = user.profile.language
    hero_img_src = _get_hero_img_src(lang_code)
    return render_to_string(
        f"emails/first_time_user.{format}",
        {
            "in_bundle_country": bundle_plans["available_in_country"],
            "SITE_ORIGIN": settings.SITE_ORIGIN,
            "hero_img_src": hero_img_src,
            "language": lang_code,
        },
    )


@time_if_enabled("ses_send_raw_email")
def ses_send_raw_email(
    source_address: str,
    destination_address: str,
    message: EmailMessage,
) -> SendRawEmailResponseTypeDef:
    """Send *message* via SES raw email.

    Raises ValueError when the SES client or configuration set is missing;
    logs and re-raises ClientError from the SES call.
    """
    client = ses_client()
    if client is None:
        raise ValueError("client must have a value")
    if not settings.AWS_SES_CONFIGSET:
        raise ValueError("settings.AWS_SES_CONFIGSET must have a value")

    data = message.as_string()
    try:
        ses_response = client.send_raw_email(
            Source=source_address,
            Destinations=[destination_address],
            RawMessage={"Data": data},
            ConfigurationSetName=settings.AWS_SES_CONFIGSET,
        )
        incr_if_enabled("ses_send_raw_email", 1)
        return ses_response
    except ClientError as e:
        logger.error("ses_client_error_raw_email", extra=e.response["Error"])
        raise


def urlize_and_linebreaks(text, autoescape=True):
    """Apply Django's urlize then linebreaksbr filters to *text*."""
    return linebreaksbr(urlize(text, autoescape=autoescape), autoescape=autoescape)


def get_reply_to_address(premium: bool = True) -> str:
    """Return the address that relays replies."""
    if premium:
        _, reply_to_address = parseaddr(
            "replies@{}".format(get_domains_from_settings().get("RELAY_FIREFOX_DOMAIN"))
        )
    else:
        _, reply_to_address = parseaddr(settings.RELAY_FROM_ADDRESS)
    return reply_to_address


def truncate(max_length: int, value: str) -> str:
    """
    Truncate a string to a maximum length.

    If the value is all ASCII, the truncation suffix will be ...
    If the value is non-ASCII, the truncation suffix will be … (Unicode ellipsis)
    """
    if len(value) <= max_length:
        return value
    ellipsis = "..."  # ASCII Ellipsis
    try:
        value.encode("ascii")
    except UnicodeEncodeError:
        ellipsis = "…"
    return Truncator(value).chars(max_length, truncate=ellipsis)


class InvalidFromHeader(Exception):
    """Raised when an original From: address cannot be parsed into a header."""

    pass


def generate_from_header(original_from_address: str, relay_mask: str) -> str:
    """
    Return a From: header str using the original sender and a display name
    that refers to Relay.

    This format was introduced in June 2023 with MPP-2117.
    """
    # Strip line separators so the address can't smuggle extra header lines.
    oneline_from_address = (
        original_from_address.replace("\u2028", "").replace("\r", "").replace("\n", "")
    )
    display_name, original_address = parseaddr(oneline_from_address)
    try:
        parsed_address = Address(addr_spec=original_address)
    except (InvalidHeaderDefect, IndexError, HeaderParseError) as e:
        # TODO: MPP-3407, MPP-3417 - Determine how to handle these
        raise InvalidFromHeader from e

    # Truncate the display name to 71 characters, so the sender portion fits on the
    # first line of a multi-line "From:" header, if it is ASCII. A utf-8 encoded header
    # will be 226 chars, still below the 998 limit of RFC 5322 2.1.1.
    max_length = 71

    if display_name:
        short_name = truncate(max_length, display_name)
        short_address = truncate(max_length, parsed_address.addr_spec)
        sender = f"{short_name} <{short_address}>"
    else:
        # Use the email address if the display name was not originally set
        display_name = parsed_address.addr_spec
        sender = truncate(max_length, display_name)
    return formataddr((f"{sender} [via Relay]", relay_mask))


def get_message_id_bytes(message_id_str: str) -> bytes:
    """Return the local part of a Message-ID (before "@", after any "<"),
    stripped and encoded as bytes."""
    message_id = message_id_str.split("@", 1)[0].rsplit("<", 1)[-1].strip()
    return message_id.encode()


def b64_lookup_key(lookup_key: bytes) -> str:
    """URL-safe base64 encode a lookup key for storage/comparison."""
    return base64.urlsafe_b64encode(lookup_key).decode("ascii")


def derive_reply_keys(message_id: bytes) -> tuple[bytes, bytes]:
    """Derive the lookup key and encryption key from an aliased message id."""
    # NOTE(review): the info strings say "replay" (likely a typo for "relay"),
    # but they are key-derivation constants — changing them would invalidate
    # all previously derived keys, so they must stay as-is.
    algorithm = hashes.SHA256()
    hkdf = HKDFExpand(algorithm=algorithm, length=16, info=b"replay replies lookup key")
    lookup_key = hkdf.derive(message_id)
    hkdf = HKDFExpand(
        algorithm=algorithm, length=32, info=b"replay replies encryption key"
    )
    encryption_key = hkdf.derive(message_id)
    return (lookup_key, encryption_key)


def encrypt_reply_metadata(key: bytes, payload: dict[str, str]) -> str:
    """Encrypt the given payload into a JWE, using the given key."""
    # This is a bit dumb, we have to base64-encode the key in order to load it :-/
    k = jwcrypto.jwk.JWK(
        kty="oct", k=base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
    )
    e = jwcrypto.jwe.JWE(
        json.dumps(payload), json.dumps({"alg": "dir", "enc": "A256GCM"}), recipient=k
    )
    return cast(str, e.serialize(compact=True))


def decrypt_reply_metadata(key, jwe):
    """Decrypt the given JWE into a json payload, using the given key."""
    # This is a bit dumb, we have to base64-encode the key in order to load it :-/
    k = jwcrypto.jwk.JWK(
        kty="oct", k=base64.urlsafe_b64encode(key).rstrip(b"=").decode("ascii")
    )
    e = jwcrypto.jwe.JWE()
    e.deserialize(jwe)
    e.decrypt(k)
    return e.plaintext


def _get_bucket_and_key_from_s3_json(message_json):
    """Extract (bucket, object_key) from an SES "Received" notification dict;
    returns (None, None) when the notification has no S3-stored content."""
    # Only Received notifications have S3-stored data
    notification_type = message_json.get("notificationType")
    if notification_type != "Received":
        return None, None

    if "receipt" in message_json and "action" in message_json["receipt"]:
        message_json_receipt = message_json["receipt"]
    else:
        logger.error(
            "sns_inbound_message_without_receipt",
            extra={"message_json_keys": message_json.keys()},
        )
        return None, None

    bucket = None
    object_key = None
    try:
        if "S3" in message_json_receipt["action"]["type"]:
            bucket = message_json_receipt["action"]["bucketName"]
            object_key = message_json_receipt["action"]["objectKey"]
    except (KeyError, TypeError):
        logger.error(
            "sns_inbound_message_receipt_malformed",
            extra={"receipt_action": message_json_receipt["action"]},
        )

    return bucket, object_key


@time_if_enabled("s3_get_message_content")
def get_message_content_from_s3(bucket, object_key):
    """Read and return the S3 object body bytes; returns None when either
    bucket or object_key is falsy."""
    if bucket and object_key:
        client = s3_client()
        if client is None:
            raise ValueError("client must not be None")
        streamed_s3_object = client.get_object(Bucket=bucket, Key=object_key).get(
            "Body"
        )
        return streamed_s3_object.read()


@time_if_enabled("s3_remove_message_from")
def remove_message_from_s3(bucket, object_key):
    """Delete an S3 object; returns the DeleteMarker on success, False
    otherwise. ClientErrors are logged (and counted), not raised."""
    if bucket is None or object_key is None:
        return False
    try:
        client = s3_client()
        if client is None:
            raise ValueError("client must not be None")
        response = client.delete_object(Bucket=bucket, Key=object_key)
        return response.get("DeleteMarker")
    except ClientError as e:
        if e.response["Error"].get("Code", "") == "NoSuchKey":
            logger.error("s3_delete_object_does_not_exist", extra=e.response["Error"])
        else:
            logger.error("s3_client_error_delete_email", extra=e.response["Error"])
        incr_if_enabled("message_not_removed_from_s3", 1)
    return False


def set_user_group(user):
    """Add *user* to the internal Django group matching their email domain
    (Mozilla Corporation / Foundation / Pocket); no-op for other domains."""
    if "@" not in user.email:
        return None
    email_domain = user.email.split("@")[1]
    group_attribute = {
        "mozilla.com": "mozilla_corporation",
        "mozillafoundation.org": "mozilla_foundation",
        "getpocket.com": "pocket",
    }
    group_name = group_attribute.get(email_domain)
    if not group_name:
        return None
    internal_group_qs = Group.objects.filter(name=group_name)
    internal_group = internal_group_qs.first()
    if internal_group is None:
        return None
    internal_group.user_set.add(user)


def convert_domains_to_regex_patterns(domain_pattern):
    """Build a regex matching a quoted URL on *domain_pattern* or any of its
    subdomains; group 1 is the quote char, group 2 the URL."""
    return r"""(["'])(\S*://(\S*\.)*""" + re.escape(domain_pattern) + r"\S*)\1"


def count_tracker(html_content, trackers):
    """Count occurrences of each tracker domain's URLs in *html_content*;
    returns {"count": total, "trackers": {domain: count}}."""
    tracker_total = 0
    details = {}
    # html_content needs to be str for count()
    for tracker in trackers:
        pattern = convert_domains_to_regex_patterns(tracker)
        html_content, count = re.subn(pattern, "", html_content)
        if count:
            tracker_total += count
            details[tracker] = count
    return {"count": tracker_total, "trackers": details}


def count_all_trackers(html_content):
    """Count general and strict trackers in *html_content*, emitting metrics
    and a study-log summary; returns None."""
    general_detail = count_tracker(html_content, general_trackers())
    strict_detail = count_tracker(html_content, strict_trackers())

    incr_if_enabled("tracker.general_count", general_detail["count"])
    incr_if_enabled("tracker.strict_count", strict_detail["count"])
    study_logger.info(
        "email_tracker_summary",
        extra={"level_one": general_detail, "level_two": strict_detail},
    )


def remove_trackers(html_content, from_address, datetime_now, level="general"):
    """Replace tracker links in *html_content* with tracker-warning URLs.

    Returns (changed_content, tracker_details) where tracker_details holds the
    replacement count and the level-one tracker counts; also logs a summary.
    """
    trackers = general_trackers() if level == "general" else strict_trackers()
    tracker_removed = 0
    changed_content = html_content

    for tracker in trackers:
        pattern = convert_domains_to_regex_patterns(tracker)

        def convert_to_tracker_warning_link(matchobj):
            # Rewrite the matched quoted URL to a Relay warning page whose
            # fragment carries the original link details, URL-encoded.
            quote, original_link, _ = matchobj.groups()
            tracker_link_details = {
                "sender": from_address,
                "received_at": datetime_now,
                "original_link": original_link,
            }
            anchor = quote_plus(json.dumps(tracker_link_details, separators=(",", ":")))
            url = f"{settings.SITE_ORIGIN}/contains-tracker-warning/#{anchor}"
            return f"{quote}{url}{quote}"

        changed_content, matched = re.subn(
            pattern, convert_to_tracker_warning_link, changed_content
        )
        tracker_removed += matched

    # Count against the ORIGINAL content, not the rewritten one.
    level_one_detail = count_tracker(html_content, general_trackers())
    level_two_detail = count_tracker(html_content, strict_trackers())

    tracker_details = {
        "tracker_removed": tracker_removed,
        "level_one": level_one_detail,
    }
    logger_details = {"level": level, "level_two": level_two_detail}
    logger_details.update(tracker_details)
    info_logger.info(
        "email_tracker_summary",
        extra=logger_details,
    )

    return changed_content, tracker_details


def encode_dict_gza85(data: dict[str, Any]) -> str:
    """
    Encode a dict to the compressed Ascii85 format

    The dict will be JSON-encoded will be compressed, Ascii85-encoded with
    padding, and split by newlines into 1024-bytes chunks. This can be used to
    ensure it fits into a GCP log entry, which has a 64KB limit per label value.
    """
    return base64.a85encode(
        zlib.compress(json.dumps(data).encode()), wrapcol=1024, pad=True
    ).decode("ascii")


def decode_dict_gza85(encoded_data: str) -> dict[str, Any]:
    """Decode a dict encoded with _encode_dict_gza85."""
    data = json.loads(
        zlib.decompress(base64.a85decode(encoded_data.encode("ascii"))).decode()
    )
    if not isinstance(data, dict):
        raise ValueError("Encoded data is not a dict")
    if any(not isinstance(key, str) for key in data):
        raise ValueError("Encoded data has non-str key")
    return data