pyignite/connection/ssl.py (46 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ssl
from ssl import SSLContext

from pyignite.constants import SSL_DEFAULT_CIPHERS, SSL_DEFAULT_VERSION
from pyignite.exceptions import ParameterError

# Every parameter name accepted by `check_ssl_params`. Built once at import
# time as a frozenset so each membership test is O(1).
_SSL_PARAM_NAMES = frozenset((
    'use_ssl',
    'ssl_version',
    'ssl_ciphers',
    'ssl_cert_reqs',
    'ssl_keyfile',
    'ssl_keyfile_password',
    'ssl_certfile',
    'ssl_ca_certfile',
))


def wrap(socket, ssl_params):
    """
    Wrap socket in SSL wrapper.

    :param socket: socket object to wrap,
    :param ssl_params: mapping of SSL parameters (keys as accepted by
     `check_ssl_params`),
    :return: `socket` itself when `use_ssl` is missing or falsy, otherwise
     the object returned by `SSLContext.wrap_socket` for the context built
     by `create_ssl_context`.
    """
    if not ssl_params.get('use_ssl'):
        return socket

    context = create_ssl_context(ssl_params)
    return context.wrap_socket(sock=socket)


def check_ssl_params(params):
    """
    Reject any parameter name that is not a known SSL parameter.

    :param params: iterable of parameter names (iterating a dict yields its
     keys),
    :raises ParameterError: on the first name not in the accepted set.
    """
    for param in params:
        if param not in _SSL_PARAM_NAMES:
            raise ParameterError((
                'Unexpected parameter for connection initialization: `{}`'
            ).format(param))


def create_ssl_context(ssl_params):
    """
    Build an `ssl.SSLContext` from connection SSL parameters.

    Recognised keys (all optional): `use_ssl`, `ssl_version`
    (default `SSL_DEFAULT_VERSION`), `ssl_ciphers` (default
    `SSL_DEFAULT_CIPHERS`), `ssl_cert_reqs` (default `ssl.CERT_NONE`),
    `ssl_keyfile`, `ssl_keyfile_password`, `ssl_certfile`,
    `ssl_ca_certfile`.

    :param ssl_params: mapping of SSL parameters,
    :return: a configured `SSLContext`, or None when `use_ssl` is missing
     or falsy,
    :raises ValueError: if `ssl_keyfile` is given without `ssl_certfile`.
     Errors raised by the `ssl` module while loading certificates or
     setting ciphers propagate unchanged.
    """
    if not ssl_params.get('use_ssl'):
        return None

    keyfile = ssl_params.get('ssl_keyfile', None)
    certfile = ssl_params.get('ssl_certfile', None)

    # A private key on its own is useless: load_cert_chain needs the cert.
    if keyfile and not certfile:
        raise ValueError("certfile must be specified")

    password = ssl_params.get('ssl_keyfile_password', None)
    ssl_version = ssl_params.get('ssl_version', SSL_DEFAULT_VERSION)
    ciphers = ssl_params.get('ssl_ciphers', SSL_DEFAULT_CIPHERS)
    cert_reqs = ssl_params.get('ssl_cert_reqs', ssl.CERT_NONE)
    ca_certs = ssl_params.get('ssl_ca_certfile', None)

    context = SSLContext(ssl_version)

    # Some protocols (e.g. ssl.PROTOCOL_TLS_CLIENT) enable check_hostname by
    # default, and CPython refuses to set verify_mode to CERT_NONE while it
    # is on. Hostname checking is meaningless without verification (and
    # `wrap` never passes server_hostname), so turn it off first when the
    # caller explicitly asked for no verification.
    if cert_reqs == ssl.CERT_NONE:
        context.check_hostname = False
    context.verify_mode = cert_reqs

    if ca_certs:
        context.load_verify_locations(ca_certs)
    if certfile:
        context.load_cert_chain(certfile, keyfile, password)
    if ciphers:
        context.set_ciphers(ciphers)

    return context