tsqa/endpoint.py (298 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import threading
import requests
import flask
import socket
import SocketServer
import ssl

from collections import defaultdict
from wsgiref.simple_server import make_server

# dict of testid -> {client_request, client_response}
REQUESTS = defaultdict(dict)


# TODO: some request/response class to load the various library's implementations and allow for comparison
class TrackingRequests(object):
    '''
    This class gives you a "requests"-like object whose methods (get, post,
    ...) return a dict of:
        - client_request
        - client_response
        - server_request
        - server_response

    assuming the request is going to the instance of DynamicHTTPEndpoint (or
    TrackingWSGIServer) this object was created with. If the tracked request
    never reached that server (e.g. a proxy answered it itself),
    server_request / server_response are None.

    In general this is useful for a proxy testing framework because you
    commonly need to check that the proxy (for example) added a header to the
    request before the origin got it.
    '''
    def __init__(self, endpoint):
        self.endpoint = endpoint

    def __getattr__(self, name):
        '''
        Return a wrapper around ``requests.<name>`` that tags the request with
        a tracking header and collects the server-side request/response.

        Raises AttributeError for private/dunder names (so copy, pickle, repr
        etc. behave normally) and for names ``requests`` does not provide.
        '''
        if name.startswith('_'):
            raise AttributeError(name)
        # resolve eagerly so hasattr()/typos fail at attribute access
        func = getattr(requests, name)

        def handler_function(*args, **kwargs):
            # copy the caller's headers (never mutate them) and tolerate None
            headers = dict(kwargs.get('headers') or {})
            key = self.endpoint.get_tracking_key()
            headers[self.endpoint.TRACKING_HEADER] = key
            kwargs['headers'] = headers

            resp = func(*args, **kwargs)
            server_resp = self.endpoint.get_tracking_by_key(key)

            # TODO: create intermediate objects that you can compare
            return {
                'client_request': resp.request,
                'client_response': resp,
                # missing if the request never reached the tracked server
                'server_request': server_resp.get('request'),
                'server_response': server_resp.get('response'),
            }
        return handler_function


class _TrackingMixin(object):
    '''
    Request/response tracking shared by the flask-based server threads.

    Subclasses must call ``_init_tracking()`` in ``__init__``, call
    ``_install_tracking_hooks(app)`` on their flask app, and assign
    ``self.server`` (a wsgiref server) once it is bound.
    '''
    TRACKING_HEADER = '__cool_test_header__'  # TODO: better name?

    @property
    def address(self):
        '''
        Return a tuple of (ip, port) that this thread is listening on.
        '''
        # server_address is already an (ip, port) tuple once bound
        return (self.server.server_address[0], self.server.server_port)

    def url(self, path=''):
        '''
        Get the url for the given path in this endpoint
        '''
        if path and not path.startswith('/'):
            path = '/' + path
        return 'http://127.0.0.1:{0}{1}'.format(self.address[1], path)

    def _init_tracking(self):
        '''
        Create the tracking storage: key -> {'request': ..., 'response': ...}
        '''
        self._tracked_requests = {}
        # keys come from a counter (not len()) so they stay unique even when
        # the hooks add entries for keys this object did not hand out
        self._tracking_counter = itertools.count()

    def _install_tracking_hooks(self, app):
        '''
        Register before/after request hooks on ``app`` that save the request
        and response of any request carrying TRACKING_HEADER.
        '''
        @app.before_request
        def save_request():
            '''
            If the tracking header is set, save the request
            '''
            key = flask.request.headers.get(self.TRACKING_HEADER)
            if key:
                # store the real Request object, not the context-local proxy,
                # so it can still be inspected after the request has finished
                self._tracked_requests[key] = {
                    'request': flask.request._get_current_object(),
                }

        @app.after_request
        def save_response(response):
            '''
            If the tracking header is set, save the response
            '''
            key = flask.request.headers.get(self.TRACKING_HEADER)
            if key:
                # setdefault: an earlier before_request hook of the app may
                # have short-circuited, so save_request may not have run
                self._tracked_requests.setdefault(key, {})['response'] = response
            return response

    def get_tracking_key(self):
        '''
        Return a new key for tracking a request by key
        '''
        while True:
            key = str(next(self._tracking_counter))
            if key not in self._tracked_requests:
                break
        self._tracked_requests[key] = {}
        return key

    def get_tracking_by_key(self, key):
        '''
        Return tracking data by key

        Raises KeyError if ``key`` was never handed out or seen.
        '''
        try:
            return self._tracked_requests[key]
        except KeyError:
            raise KeyError('no tracking data for key {0!r}'.format(key))


class DynamicHTTPEndpoint(_TrackingMixin, threading.Thread):
    '''
    A threaded webserver which allows you to dynamically add/remove handlers.
    This is implemented using flask (http://flask.pocoo.org/) primarily because
    it is very common and (almost more importantly) *very* picky about http
    semantics.

    To use this in a TestCase you simply need to create the thread:

        # create the thread object
        http_endpoint = tsqa.endpoint.DynamicHTTPEndpoint(port=cls.endpoint_port)
        # start the thread
        http_endpoint.start()
        # wait for the webserver to listen
        http_endpoint.ready.wait()

    At this point the webserver is listening and returning 404 for all
    requests (if startup failed, ``http_endpoint.error`` holds the exception).
    To register an endpoint you must (1) define a request-handler function and
    (2) add that handler to the http_endpoint.

    (1): To define a request handler you must create a function which takes a
    single argument which is the Request wrapper
    (http://werkzeug.pocoo.org/docs/0.10/wrappers/#werkzeug.wrappers.Request).
    Flask supports a variety of return types
    (http://flask.pocoo.org/docs/0.10/quickstart/#about-responses); for this
    example we will simply return "hello world":

        def handler_func(request):
            return "hello world"

    (2): Now that we have a function, we can add it as a handler to a context
    path:

        http_endpoint.add_handler('/hello', handler_func)
    '''

    def __init__(self, port=0):
        threading.Thread.__init__(self)
        self._init_tracking()

        # error in startup
        self.error = None

        self.daemon = True
        self.port = port
        self.ready = threading.Event()

        # dict of pathname (no starting /) -> function
        self._handlers = {}

        self.app = flask.Flask(__name__)
        self.app.debug = True
        self._install_tracking_hooks(self.app)

        @self.app.route('/', defaults={'path': ''})
        @self.app.route('/<path:path>')
        def catch_all(path=''):
            # single lookup so a concurrent remove_handler can't race us
            handler = self._handlers.get(path)
            if handler is not None:
                return handler(flask.request)

            # return a 404 since we didn't find it
            return ('', 404)

        # A little magic to make flask accept *all* methods on the catch_all path
        for rule in self.app.url_map.iter_rules():
            rule.methods = None
            rule.refresh()

    def normalize_path(self, path):
        '''
        Normalize the path, since its common (and convenient) to start with / in your paths
        '''
        if path.startswith('/'):
            return path[1:]
        return path

    def add_handler(self, path, func):
        '''
        Add a new handler attached to a specific path

        Raises ValueError if a handler is already registered for ``path``.
        '''
        path = self.normalize_path(path)
        if path in self._handlers:
            raise ValueError('handler already registered for path {0!r}'.format(path))
        self._handlers[path] = func

    def remove_handler(self, path):
        '''
        remove a handler attached to a specific path

        Raises KeyError if no handler is registered for ``path``.
        '''
        path = self.normalize_path(path)
        if path not in self._handlers:
            raise KeyError('no handler registered for path {0!r}'.format(path))
        del self._handlers[path]

    def clear_handlers(self):
        '''
        Clear all handlers that have been registered
        '''
        self._handlers = {}

    def run(self):
        '''
        Bind the server, signal ``ready`` (recording any startup error in
        ``self.error``), then serve forever.
        '''
        try:
            # WSGIServer inherits allow_reuse_address from HTTPServer, so
            # SO_REUSEADDR is already applied before bind()
            self.server = make_server('', self.port, self.app.wsgi_app)
        except Exception as e:
            self.error = e
            return
        finally:
            # always release waiters, whether startup succeeded or not
            self.ready.set()

        # serve it
        self.server.serve_forever()


class TrackingWSGIServer(_TrackingMixin, threading.Thread):
    '''
    A threaded webserver which will wrap any flask app and track request/response headers to the origin

        # create the thread object
        http_endpoint = tsqa.endpoint.TrackingWSGIServer(app)
        # start the thread
        http_endpoint.start()
        # wait for the webserver to listen
        http_endpoint.ready.wait()

    If startup failed, ``http_endpoint.error`` holds the exception.
    '''

    def __init__(self, app, port=0):
        threading.Thread.__init__(self)
        self._init_tracking()

        # error in startup
        self.error = None

        self.daemon = True
        self.port = port
        self.ready = threading.Event()

        self.app = app
        self.app.debug = True
        self._install_tracking_hooks(self.app)

    def run(self):
        '''
        Bind the server, signal ``ready`` (recording any startup error in
        ``self.error``), then serve forever.
        '''
        try:
            self.server = make_server('', self.port, self.app.wsgi_app)
        except Exception as e:
            self.error = e
            return
        finally:
            self.ready.set()

        # serve it
        self.server.serve_forever()


class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    # must be a class attribute: TCPServer.__init__ binds immediately, so
    # setting it on the instance afterwards has no effect
    allow_reuse_address = True


class SocketServerDaemon(threading.Thread):
    '''
    A daemon thread to run a socketserver

    handler: subclass of SocketServer.BaseRequestHandler
    port: port to bind on 0.0.0.0 (0 picks a free port; ``self.port`` is
          updated to the bound port before ``ready`` is set)
    If startup failed, ``self.error`` holds the exception.
    '''
    def __init__(self, handler, port=0):
        threading.Thread.__init__(self)
        self.port = port
        self.handler = handler
        self.error = None
        self.ready = threading.Event()
        self.daemon = True

    def run(self):
        try:
            self.server = ThreadedTCPServer(('0.0.0.0', self.port), self.handler)
            self.port = self.server.socket.getsockname()[1]
        except Exception as e:
            self.error = e
            return
        finally:
            self.ready.set()

        # Activate the server; this will keep running until you
        # interrupt the program with Ctrl-C
        self.server.serve_forever()


class ThreadedSSLTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    '''
    Threaded TCP server that wraps every accepted connection in TLS.
    '''
    # class attribute so SO_REUSEADDR is applied before the bind in __init__
    allow_reuse_address = True

    def __init__(self, server_address, RequestHandlerClass, certfile, keyfile,
                 ssl_version=ssl.PROTOCOL_TLSv1, bind_and_activate=True):
        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
        self.certfile = certfile
        self.keyfile = keyfile
        self.ssl_version = ssl_version

    def get_request(self):
        newsocket, fromaddr = self.socket.accept()
        try:
            connstream = ssl.wrap_socket(newsocket,
                                         server_side=True,
                                         certfile=self.certfile,
                                         keyfile=self.keyfile,
                                         ssl_version=self.ssl_version,
                                         )
        except Exception:
            # don't leak the raw socket when the handshake fails
            newsocket.close()
            raise
        return connstream, fromaddr


class SSLSocketServerDaemon(threading.Thread):
    '''
    A daemon thread to run a socketserver

    This is just a thread wrapper to https://docs.python.org/2/library/socketserver.html
    If startup failed, ``self.error`` holds the exception.
    '''
    def __init__(self, handler, cert, key, port=0):
        '''
        handler: subclass (the class, not an instance) of SocketServer.BaseRequestHandler
            https://docs.python.org/2/library/socketserver.html#socketserver-tcpserver-example
        cert: path to certificate file
        key: path to key file
        port: port to bind on 0.0.0.0 (0 picks a free port; ``self.port`` is
              updated to the bound port before ``ready`` is set)
        '''
        # for testing it is *very* common to have self-signed certs, so we
        # will disable warnings so we don't flood logs
        requests.packages.urllib3.disable_warnings()

        threading.Thread.__init__(self)
        self.handler = handler
        self.cert = cert
        self.key = key
        self.port = port
        self.error = None
        self.ready = threading.Event()
        self.daemon = True

    def run(self):
        try:
            self.server = ThreadedSSLTCPServer(('0.0.0.0', self.port),
                                               self.handler,
                                               self.cert,
                                               self.key,
                                               )
            self.port = self.server.socket.getsockname()[1]
        except Exception as e:
            self.error = e
            return
        finally:
            self.ready.set()

        self.server.serve_forever()