privaterelay/utils.py (170 lines of code) (raw):
from __future__ import annotations
import json
import logging
import random
from collections.abc import Callable
from decimal import Decimal
from functools import cache, wraps
from pathlib import Path
from typing import TYPE_CHECKING, ParamSpec, TypedDict, TypeVar, cast
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser
from django.http import Http404, HttpRequest
from waffle import get_waffle_flag_model
from waffle.models import logger as waffle_logger
from waffle.utils import get_cache as get_waffle_cache
from waffle.utils import get_setting as get_waffle_setting
from privaterelay.country_utils import (
AcceptLanguageError,
_get_cc_from_lang,
_get_cc_from_request,
guess_country_from_accept_lang,
)
from privaterelay.sp3_plans import SP3PlanCountryLangMapping
from .plans import (
CountryStr,
LanguageStr,
PeriodStr,
PlanCountryLangMapping,
get_premium_country_language_mapping,
)
if TYPE_CHECKING:
from .glean_interface import RelayGleanLogger
info_logger = logging.getLogger("eventsinfo")
class CountryInfo(TypedDict):
country_code: str
countries: list[CountryStr]
available_in_country: bool
plan_country_lang_mapping: PlanCountryLangMapping | SP3PlanCountryLangMapping
def get_countries_info_from_request_and_mapping(
request: HttpRequest, mapping: PlanCountryLangMapping | SP3PlanCountryLangMapping
) -> CountryInfo:
country_code = _get_cc_from_request(request)
countries = sorted(mapping.keys())
available_in_country = country_code in countries
return {
"country_code": country_code,
"countries": countries,
"available_in_country": available_in_country,
"plan_country_lang_mapping": mapping,
}
def get_countries_info_from_lang_and_mapping(
accept_lang: str, mapping: PlanCountryLangMapping
) -> CountryInfo:
country_code = _get_cc_from_lang(accept_lang)
countries = sorted(mapping.keys())
available_in_country = country_code in countries
return {
"country_code": country_code,
"countries": countries,
"available_in_country": available_in_country,
"plan_country_lang_mapping": mapping,
}
def get_subplat_upgrade_link_by_language(
accept_language: str, period: PeriodStr = "yearly"
) -> str:
try:
country_str = guess_country_from_accept_lang(accept_language)
country = cast(CountryStr, country_str)
except AcceptLanguageError:
country = "US"
language_str = accept_language.split("-")[0].lower()
language = cast(LanguageStr, language_str)
country_lang_mapping = get_premium_country_language_mapping()
country_details = country_lang_mapping.get(country, country_lang_mapping["US"])
if language in country_details:
plan = country_details[language][period]
else:
first_key = list(country_details.keys())[0]
plan = country_details[first_key][period]
return (
f"{settings.FXA_BASE_ORIGIN}/subscriptions/products/"
f"{settings.PERIODICAL_PREMIUM_PROD_ID}?plan={plan['id']}"
)
# Generics for defining function decorators
# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
_Params = ParamSpec("_Params")
_RetVal = TypeVar("_RetVal")
def enable_or_404(
check_function: Callable[[], bool],
message: str = "This conditional view is disabled.",
) -> Callable[[Callable[_Params, _RetVal]], Callable[_Params, _RetVal]]:
"""
Returns decorator that enables a view if a check function passes,
otherwise returns a 404.
Usage:
def percent_1():
import random
return random.randint(1, 100) == 1
@enable_if(percent_1)
def lucky_view(request):
# 1 in 100 chance of getting here
# 99 in 100 chance of 404
"""
def decorator(func: Callable[_Params, _RetVal]) -> Callable[_Params, _RetVal]:
@wraps(func)
def inner(*args: _Params.args, **kwargs: _Params.kwargs) -> _RetVal:
if check_function():
return func(*args, **kwargs)
else:
raise Http404(message) # Display a message with DEBUG=True
return inner
return decorator
def enable_if_setting(
setting_name: str,
message_fmt: str = "This view is disabled because {setting_name} is False",
) -> Callable[[Callable[_Params, _RetVal]], Callable[_Params, _RetVal]]:
"""
Returns decorator that enables a view if a setting is truthy, otherwise
returns a 404.
Usage:
@enable_if_setting("DEBUG")
def debug_only_view(request):
# DEBUG == True
Or in URLS:
path(
"developer_info",
enable_if_setting("DEBUG")(debug_only_view)
),
name="developer-info",
),
"""
def setting_is_truthy() -> bool:
return bool(getattr(settings, setting_name))
return enable_or_404(
setting_is_truthy, message_fmt.format(setting_name=setting_name)
)
def flag_is_active_in_task(flag_name: str, user: AbstractBaseUser | None) -> bool:
"""
Test if a flag is active in a task (not in a web request).
This mirrors AbstractBaseFlag.is_active, replicating these checks:
* Logs missing flags, if configured
* Creates missing flags, if configured
* Returns default for missing flags
* Checks flag.everyone
* Checks flag.users and flag.groups, if a user is passed
* Returns random results for flag.percent
It does not include:
* Overriding a flag with a query parameter
* Persisting a flag in a cookie (includes percent flags)
* Language-specific overrides (could be added later)
* Read-only mode for percent flags
When using this function, use the @override_flag decorator in tests, rather
than manually creating flags in the database.
"""
flag = get_waffle_flag_model().get(flag_name)
if not flag.pk:
log_level = get_waffle_setting("LOG_MISSING_FLAGS")
if log_level:
waffle_logger.log(log_level, "Flag %s not found", flag_name)
if get_waffle_setting("CREATE_MISSING_FLAGS"):
flag, _created = get_waffle_flag_model().objects.get_or_create(
name=flag_name,
defaults={"everyone": get_waffle_setting("FLAG_DEFAULT")},
)
cache = get_waffle_cache()
cache.set(flag._cache_key(flag.name), flag)
return bool(get_waffle_setting("FLAG_DEFAULT"))
# Removed - check for override as request query parameter
if flag.everyone:
return True
elif flag.everyone is False:
return False
# Removed - check for testing override in request query or cookie
# Removed - check for language-specific override
if user is not None:
active_for_user = flag.is_active_for_user(user)
if active_for_user is not None:
return bool(active_for_user)
if flag.percent and flag.percent > 0:
# Removed - check for waffles attribute of request
# Removed - check for cookie setting for flag
# Removed - check for read-only mode
if Decimal(str(random.uniform(0, 100))) <= flag.percent: # noqa: S311
# Removed - setting the flag for future checks
return True
return False
class VersionInfo(TypedDict):
source: str
version: str
commit: str
build: str
@cache
def get_version_info(base_dir: str | Path | None = None) -> VersionInfo:
"""Return version information written by build process."""
if base_dir is None:
base_path = Path(settings.BASE_DIR)
else:
base_path = Path(base_dir)
version_json_path = base_path / "version.json"
info = {}
if version_json_path.exists():
with version_json_path.open() as version_file:
try:
info = json.load(version_file)
except ValueError:
pass
if not hasattr(info, "get"):
info = {}
version_info = VersionInfo(
source=info.get("source", "https://github.com/mozilla/fx-private-relay"),
version=info.get("version", "unknown"),
commit=info.get("commit", "unknown"),
build=info.get("build", "not built"),
)
return version_info
@cache
def glean_logger() -> RelayGleanLogger:
from .glean_interface import RelayGleanLogger
version_info = get_version_info()
return RelayGleanLogger(
application_id="relay-backend",
app_display_version=version_info["version"],
channel=settings.RELAY_CHANNEL,
)