# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from thrift.transport.TSocket import *
from thrift.transport.TTransport import *

# workaround for a python bug.  see http://bugs.python.org/issue8484
import hashlib
import logging
import os
import socket
import ssl
import sys
import traceback


def _detect_legacy_ssl() -> bool:
    """
    Return True when the ``ssl`` module predates the modern TLS API.

    The modern API (Python >= 2.7.9 / 3.2+) is recognised by the presence of
    ``ssl.SSLContext`` together with the ``OP_NO_SSLv2``, ``OP_NO_SSLv3`` and
    ``OP_NO_TLSv1`` option flags. If any one of them is missing, the
    interpreter is treated as legacy.
    """
    needed = ("SSLContext", "OP_NO_SSLv2", "OP_NO_SSLv3", "OP_NO_TLSv1")
    for attribute_name in needed:
        if not hasattr(ssl, attribute_name):
            return True
    return False


# Computed once at import time; the module-level ``if _is_legacy_ssl:`` branch
# uses it to choose which helper implementations get defined.
_is_legacy_ssl: bool = _detect_legacy_ssl()


def _best_possible_default_version():
    """Pick the ssl protocol constant used when the caller specified none.

    Legacy interpreters get ``ssl.PROTOCOL_TLSv1``. Otherwise
    ``ssl.PROTOCOL_TLS`` is preferred, falling back to its older alias
    ``ssl.PROTOCOL_SSLv23``.
    """
    if not _is_legacy_ssl:
        # PROTOCOL_TLS (Python >= 3.6) is the recommended spelling of
        # PROTOCOL_SSLv23; at the time of writing the two are aliases.
        preference_order = ("PROTOCOL_TLS", "PROTOCOL_SSLv23")
        available = (
            getattr(ssl, name) for name in preference_order if hasattr(ssl, name)
        )
        return next(available)

    # Python < 2.7.9 cannot switch SSLv2/SSLv3 off (no OP_NO_SSLv2/OP_NO_SSLv3).
    # SSLv23 would therefore still allow those protocols, so pin TLS 1.0.
    return ssl.PROTOCOL_TLSv1


if _is_legacy_ssl:
    # Set once the legacy-TLS warning has been logged, so it is emitted only once.
    _has_warned = False

    def _warn_if_legacy():
        """Log (once per process) that this interpreter is limited to TLS 1.0."""
        global _has_warned
        if not _has_warned:
            logging.warning(
                "You are using an old version of Python (< 2.7.9) that is "
                "limited to an old version of TLS (1.0) with known security "
                "vulnerabilities. "
            )
            _has_warned = True

    def _warn_if_insecure_version_specified(version):
        """No-op: a legacy interpreter has no more secure version to suggest."""
        pass

    def _get_ssl_socket(
        socket,
        ssl_version,
        cert_reqs=ssl.CERT_NONE,
        ca_certs=None,
        keyfile=None,
        certfile=None,
        **kwargs
    ):
        """Wrap ``socket`` using the legacy ``ssl.SSLSocket`` constructor.

        Extra keyword arguments (e.g. ``disable_weaker_versions``) are accepted
        for signature compatibility and ignored: this code path has no way to
        disable protocol versions.
        """
        return ssl.SSLSocket(
            socket,
            ssl_version=ssl_version,
            cert_reqs=cert_reqs,
            ca_certs=ca_certs,
            keyfile=keyfile,
            certfile=certfile,
        )


else:

    def _warn_if_legacy():
        """No-op: this interpreter exposes the modern TLS configuration API."""
        pass

    def _warn_if_insecure_version_specified(version):
        """Log a warning if ``version`` is an explicitly requested weak protocol.

        ``None`` (no explicit choice) never warns.
        """
        if version is None:
            return

        # Only consult constants this ssl build actually defines.
        weak_protocol_names = (
            "PROTOCOL_TLSv1",
            "PROTOCOL_SSLv2",
            "PROTOCOL_SSLv3",
            "PROTOCOL_TLSv1_1",
        )
        blacklist = [
            getattr(ssl, name) for name in weak_protocol_names if hasattr(ssl, name)
        ]

        if version in blacklist:
            logging.warning(
                "You are constructing TSSLSocket and intentionally specifying "
                "a weak, vulnerable ssl_version on a platform that has secure "
                "versions available! Leave ssl_version unspecified and we will "
                "automatically choose a suitable, secure version for you."
            )

    def _restrict_protocol_versions(ctx, disable_weaker_versions):
        """Forbid SSLv2/SSLv3 on ``ctx``, and also TLS 1.0/1.1 when requested.

        Prefers ``SSLContext.minimum_version`` where available, because setting
        the ``OP_NO_TLSv1*`` option flags is deprecated (and warns) on newer
        Python versions. Falls back to the option flags on older interpreters
        and for fixed-version contexts, which reject ``minimum_version``.
        """
        tls_version = getattr(ssl, "TLSVersion", None)
        if tls_version is not None and hasattr(ctx, "minimum_version"):
            floor = (
                tls_version.TLSv1_2 if disable_weaker_versions else tls_version.TLSv1
            )
            try:
                # Only ever raise the floor: never loosen a stricter minimum
                # that the system OpenSSL configuration may already impose.
                if ctx.minimum_version < floor:
                    ctx.minimum_version = floor
                return
            except ValueError:
                # Contexts created for a single protocol version (e.g.
                # PROTOCOL_TLSv1_2) refuse minimum_version changes.
                pass

        ctx.options |= ssl.OP_NO_SSLv2
        ctx.options |= ssl.OP_NO_SSLv3

        if disable_weaker_versions:
            ctx.options |= ssl.OP_NO_TLSv1

            # Python 2.7.9+ has this symbol, Python 3 only gets this at 3.4
            if hasattr(ssl, "OP_NO_TLSv1_1"):
                ctx.options |= ssl.OP_NO_TLSv1_1

    def _get_ssl_socket(
        socket,
        ssl_version,
        cert_reqs=ssl.CERT_NONE,
        ca_certs=None,
        keyfile=None,
        certfile=None,
        disable_weaker_versions=True,
    ):
        """Wrap ``socket`` in a client-side TLS connection via ``ssl.SSLContext``.

        ``certfile``/``keyfile`` supply the client certificate, ``ca_certs``
        the trusted roots, and ``cert_reqs`` the peer verification mode.
        SSLv2/SSLv3 are always disabled. TLS 1.0/1.1 are also disabled
        unless ``disable_weaker_versions`` is False.
        """
        ctx = ssl.SSLContext(ssl_version)
        ctx.verify_mode = cert_reqs
        if certfile is not None:
            ctx.load_cert_chain(
                certfile=certfile,
                keyfile=keyfile,
            )

        if ca_certs is not None:
            ctx.load_verify_locations(
                cafile=ca_certs,
            )

        _restrict_protocol_versions(ctx, disable_weaker_versions)

        return ctx.wrap_socket(socket)


class TSSLSocket(TSocket):
    """Socket implementation that communicates over an SSL/TLS encrypted
    channel."""

    def __init__(
        self,
        host="localhost",
        port=9090,
        unix_socket=None,
        ssl_version=None,
        cert_reqs=ssl.CERT_NONE,
        ca_certs=None,
        verify_name=False,
        keyfile=None,
        certfile=None,
        allow_weak_ssl_versions=False,
    ):
        """Initialize a TSSLSocket.

        @param ssl_version(int)  protocol version. see ssl module. If none is
                                 specified, we will default to the most
                                 reasonably secure and compatible configuration
                                 if possible.

                                 For Python versions >= 2.7.9, TLS 1.0 and
                                 TLS 1.1 are disabled by default, so at least
                                 TLS 1.2 is required (see
                                 allow_weak_ssl_versions).

                                 For Python versions < 2.7.9, we can only
                                 default to TLS 1.0, which is the best that
                                 Python guarantees to offer at this version.
                                 If you specify ssl.PROTOCOL_SSLv23, and
                                 the OpenSSL linked with Python is new enough,
                                 it is possible for a TLS 1.2 connection to be
                                 established; however, there is no way in
                                 < Python 2.7.9 to explicitly disable SSLv2
                                 and SSLv3. For that reason, we default to
                                 TLS 1.0.

        @param cert_reqs(int)    whether to verify peer certificate. see ssl
                                 module.
        @param ca_certs(str)     filename containing trusted root certs.
        @param verify_name       if False, no peer name validation is performed
                                 if True, verify subject name of peer vs 'host'
                                 if a str, verify subject name of peer vs given
                                 str
        @param keyfile           filename containing the client's private key
        @param certfile          filename containing the client's cert and
                                 optionally the private key

        @param allow_weak_ssl_versions(bool) By default, we try to disable older
                                             protocol versions. Only set this
                                             if you know what you are doing.
        """
        TSocket.__init__(self, host, port, unix_socket)
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.ssl_version = ssl_version
        self.verify_name = verify_name
        self.client_keyfile = keyfile
        self.client_certfile = certfile
        self.allow_weak_ssl_versions = allow_weak_ssl_versions
        _warn_if_legacy()
        _warn_if_insecure_version_specified(ssl_version)

    def open(self):
        """Open the underlying socket, perform the TLS handshake, and
        optionally verify the peer's certificate name.

        Raises TTransportException(NOT_OPEN) on SSL errors, socket errors, or
        a certificate-name mismatch. On any of these failures, the
        partially opened socket is closed rather than leaked.
        """
        TSocket.open(self)
        try:
            ssl_version = (
                self.ssl_version
                if self.ssl_version is not None
                else _best_possible_default_version()
            )

            sslh = _get_ssl_socket(
                self.handle,
                ssl_version=ssl_version,
                cert_reqs=self.cert_reqs,
                ca_certs=self.ca_certs,
                keyfile=self.client_keyfile,
                certfile=self.client_certfile,
                disable_weaker_versions=not self.allow_weak_ssl_versions,
            )

            if self.verify_name and not self._peerNameMatches(sslh):
                sslh.close()
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "failed to verify certificate name",
                )
        except ssl.SSLError as e:
            self.close()
            raise TTransportException(
                TTransportException.NOT_OPEN, "SSL error during handshake: " + str(e)
            ) from e
        except socket.error as e:
            self.close()
            raise TTransportException(
                TTransportException.NOT_OPEN,
                "socket error during SSL handshake: " + str(e),
            ) from e
        except TTransportException:
            # Certificate-name mismatch: release the plain socket too.
            self.close()
            raise
        self.setHandle(sslh)

    def _peerNameMatches(self, sslh):
        """Return True if the peer certificate of ``sslh`` matches the
        expected name.

        When ``verify_name`` is a str, it is compared against the certificate
        commonName(s). Otherwise ``self.host`` is compared against the
        commonName(s) and the DNS subjectAltNames.
        """
        cert = sslh.getpeercert()
        if isinstance(self.verify_name, str):
            valid_names = self._getCertNames(cert)
            name = self.verify_name
        else:
            valid_names = self._getCertNames(cert, "DNS")
            name = self.host
        return any(self._matchName(name, valid_name) for valid_name in valid_names)

    @staticmethod
    def _getCertNames(cert, includeAlt=None):
        """Returns a set containing the common name(s) for the given cert. If
        includeAlt is not None, then X509v3 alternative names of type includeAlt
        (e.g. 'DNS', 'IPADD') will be included as potential matches.

        A missing certificate (``None``, as getpeercert() returns when the peer
        sent none) yields an empty set instead of raising.
        """
        names = set()
        if not cert:
            return names
        # The "subject" field is a tuple containing the sequence of relative
        # distinguished names (RDNs) given in the certificate's data structure
        # for the principal, and each RDN is a sequence of name-value pairs.
        for rdn in cert.get("subject", ()):
            for k, v in rdn:
                if k == "commonName":
                    names.add(v)
        if includeAlt:
            for k, v in cert.get("subjectAltName", ()):
                if k == includeAlt:
                    names.add(v)
        return names

    @staticmethod
    def _matchName(name, pattern):
        """match a DNS name against a pattern. match is not case sensitive.
        a '*' in the pattern will match any single component of name."""
        name_parts = name.split(".")
        pattern_parts = pattern.split(".")
        if len(name_parts) != len(pattern_parts):
            return False
        for n, p in zip(name_parts, pattern_parts):
            if p != "*" and (n.lower() != p.lower()):
                return False
        return True


class TSSLServerSocket(TServerSocket):
    """
    SSL implementation of TServerSocket

    Note that this does not support TNonblockingServer

    Each accepted connection is wrapped with an ``ssl.SSLContext`` built from
    the current certfile / cert_reqs / ca_certs settings, which provides SSL
    negotiated encryption.
    """

    def __init__(
        self,
        port=9090,
        ssl_version=ssl.PROTOCOL_TLSv1,
        cert_reqs=ssl.CERT_NONE,
        ca_certs=None,
        certfile="cert.pem",
        unix_socket=None,
    ):
        """Initialize a TSSLServerSocket

        @param certfile: The filename of the server certificate file, defaults
                         to cert.pem
        @type certfile: str
        @param port: The port to listen on for inbound connections.
        @type port: int
        """
        self.setCertfile(certfile)
        self.setCertReqs(cert_reqs, ca_certs)
        self.ssl_version = ssl_version
        TServerSocket.__init__(self, port, unix_socket)

    def setCertfile(self, certfile):
        """Set or change the server certificate file used to wrap new
        connections.

        @param certfile: The filename of the server certificate, i.e.
                         '/etc/certs/server.pem'
        @type certfile: str

        Raises an IOError exception if the certfile is not present or
        unreadable.
        """
        if not os.access(certfile, os.R_OK):
            raise IOError("No such certfile found: %s" % (certfile))
        self.certfile = certfile

    def setCertReqs(self, cert_reqs, ca_certs):
        """Set or change the parameters used to validate the client's
        certificate.  The parameters have the same meaning as the arguments
        of the same name to python's (now removed) ssl.wrap_socket() function.
        """
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs

    def _makeSSLContext(self):
        """Build a server-side SSLContext from the current settings.

        This does the same configuration ssl.wrap_socket() performed; that
        function was removed in Python 3.12. A fresh context is built per
        connection, so setCertfile()/setCertReqs() changes apply to the next
        accept().
        """
        context = ssl.SSLContext(self.ssl_version)
        context.verify_mode = self.cert_reqs
        if self.ca_certs:
            context.load_verify_locations(self.ca_certs)
        # certfile holds the certificate and, with no keyfile, the private key.
        context.load_cert_chain(self.certfile)
        return context

    def accept(self):
        """Accept a client and complete the server-side TLS handshake.

        Returns the wrapped TSocket, or None if the handshake or SSL setup
        failed (the client socket is closed and the error is logged).
        """
        plain_client, addr = self._sock_accept()
        try:
            client = self._makeSSLContext().wrap_socket(
                plain_client, server_side=True
            )
        except (ssl.SSLError, socket.error):
            # failed handshake/ssl wrap (including resets mid-handshake),
            # close socket to client.
            plain_client.close()
            # We can't raise the exception, because it kills most TServer
            # derived serve() methods.
            # Instead, return None, and let the TServer instance deal with it
            # in other exception handling.  (but TSimpleServer dies anyway)
            logging.exception("SSL handshake with client %s failed", addr)
            return None

        return self._makeTSocketFromAccepted((client, addr))
