py_look_for_timeouts/main.py (191 lines of code) (raw):
#!/usr/bin/env python
import argparse
import ast
from . import __version__
class IllegalLine(object):
    """One timeout problem found at a specific line of a source file.

    Attributes are plain data: ``reason`` (human-readable message),
    ``node`` (the offending AST node), ``lineno`` (copied from
    ``node.lineno``) and ``filename`` (the file being checked).
    """

    def __init__(self, reason, node, filename):
        """Record the offending ``node`` together with why and where.

        :param reason: message describing the problem.
        :param node: object exposing a ``lineno`` attribute.
        :param filename: name of the file the node came from.
        """
        self.node = node
        self.filename = filename
        self.reason = reason
        # Copied eagerly so reports never have to reach into the node again.
        self.lineno = node.lineno

    def __repr__(self):
        """Return a debugging form: ``IllegalLine<reason, file:line>``."""
        location = "%s:%s" % (self.filename, self.lineno)
        return "IllegalLine<%s, %s>" % (self.reason, location)

    def __str__(self):
        """Return the report line: ``file:line<TAB>reason``."""
        fields = (self.filename, self.lineno, self.reason)
        return "%s:%d\t%s" % fields
def _intify(something):
if isinstance(something, ast.Num):
return something.n
else:
# we aren't going to evaluate anything else, so, uh
# assume it was okay
return None
def _stringify(node):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return '%s.%s' % (_stringify(node.value), node.attr)
elif isinstance(node, ast.Subscript):
return '%s[%s]' % (_stringify(node.value), _stringify(node.slice))
elif isinstance(node, ast.Index):
return _stringify(node.value)
elif isinstance(node, ast.Call):
return '%s(%s, %s)' % (
_stringify(node.func),
_stringify(node.args),
_stringify(node.keywords)
)
elif isinstance(node, list):
return '[%s]' % (', '.join(_stringify(n) for n in node))
elif isinstance(node, ast.Str):
return node.s
else:
return ast.dump(node)
class Visitor(ast.NodeVisitor):
    """AST visitor that collects timeout problems on known network calls.

    Each ``Call`` node whose callee name matches one of the recognised
    APIs (urlopen, httplib connections, the Twilio REST client, requests)
    has its timeout argument located and handed to ``checker``. Whatever
    the checker reports is accumulated in ``self.errors``.
    """

    def __init__(self, filename, checker, *args, **kwargs):
        """Create a visitor for one file.

        :param filename: name passed through to the checker for reporting.
        :param checker: callable invoked as
            ``checker(timeout_node, desc, call_node, filename, is_kwarg)``;
            it returns a list of errors or a falsy value.
        """
        self.filename = filename
        self.checker = checker
        self.errors = []
        super(Visitor, self).__init__(*args, **kwargs)

    @staticmethod
    def _is_urlopen_call(function_name):
        """Return True for ``urlopen`` and ``urllib``/``urllib2`` urlopen."""
        return function_name in (
            'urlopen',
            'urllib.urlopen',
            'urllib2.urlopen',
        )

    @staticmethod
    def _is_httplib_call(function_name):
        """Return True for bare or ``httplib.``-qualified connections."""
        return function_name in (
            'HTTPConnection',
            'HTTPSConnection',
            'httplib.HTTPConnection',
            'httplib.HTTPSConnection',
        )

    @staticmethod
    def _is_twilio_call(function_name):
        """Return True for ``TwilioRestClient``, bare or under ``rest.``."""
        return (
            function_name == 'TwilioRestClient'
            or function_name.endswith('rest.TwilioRestClient')
        )

    @staticmethod
    def _is_requests_call(function_name):
        """Return True for the recognised ``requests.<verb>`` helpers."""
        return function_name in (
            'requests.get',
            'requests.post',
            'requests.put',
            'requests.head',
            'requests.request',
        )

    def _check_timeout_call(self, node, arg_offset, kwarg_name, desc):
        """Locate the timeout of a call and run the checker on it.

        :param node: the ``ast.Call`` being inspected.
        :param arg_offset: positional index of the timeout, or None when
            the API only takes it as a keyword.
        :param kwarg_name: keyword name of the timeout argument.
        :param desc: short description of the call used in messages.
        """
        timeout = None
        is_kwarg = False
        if arg_offset is not None and len(node.args) > arg_offset:
            # Enough positional arguments were given to reach the timeout.
            timeout = node.args[arg_offset]
        elif node.keywords:
            matches = [k for k in node.keywords if k.arg == kwarg_name]
            if matches:
                is_kwarg = True
                timeout = matches[0].value

        # timeout stays None when neither form was found; the checker
        # decides what that means.
        errors = self.checker(timeout, desc, node, self.filename, is_kwarg)
        if errors:
            self.errors.extend(errors)

    def visit_Call(self, node):
        """Check ``node`` if it is a recognised call, then visit children."""
        function_name = _stringify(node.func)
        if self._is_urlopen_call(function_name):
            self._check_timeout_call(
                node,
                arg_offset=2,
                kwarg_name='timeout',
                desc='urlopen call'
            )
        elif self._is_httplib_call(function_name):
            # NOTE(review): index 5 looks like where Python 2's
            # HTTPSConnection takes timeout; plain HTTPConnection appears
            # to take it at index 3 -- confirm against the httplib docs.
            self._check_timeout_call(
                node,
                arg_offset=5,
                kwarg_name='timeout',
                desc='httplib connection'
            )
        elif self._is_twilio_call(function_name):
            self._check_timeout_call(
                node,
                arg_offset=5,
                kwarg_name='timeout',
                desc='twilio rest connection'
            )
        elif self._is_requests_call(function_name):
            self._check_timeout_call(
                node,
                arg_offset=None,
                kwarg_name='timeout',
                desc='requests call'
            )

        # NodeVisitor does not descend into a node that has its own
        # visit_* method, so recurse explicitly. Without this, calls nested
        # in the callee or arguments (``urlopen(url).read()``,
        # ``loads(requests.get(url).text)``) would never be checked.
        self.generic_visit(node)
class Checker(object):
    """Callable policy deciding whether a call's timeout is acceptable."""

    def __init__(self, allow_hardcoded=True):
        """Configure the policy.

        :param bool allow_hardcoded: when False, a literal numeric timeout
            is reported as an error.
        """
        self.allow_hardcoded = allow_hardcoded

    def __call__(self, timeout_node, desc, node, filename, is_kwarg):
        """Return a list of IllegalLine on misconfigured timeout.

        :param timeout_node: AST node holding the timeout value, or None
            if the call has no timeout.
        :param desc: short description of the call, used in the message.
        :param node: the call's AST node; supplies the reported line.
        :param str filename: name of the file being checked.
        :param is_kwarg: whether the timeout was passed by keyword.
        :returns: a one-element list of IllegalLine when there is a
            problem, otherwise None.
        """
        if not timeout_node:
            msg = '%s without a timeout arg or kwarg' % desc
            return [IllegalLine(msg, node, filename)]

        value = _intify(timeout_node)
        if value is None:
            # Not a numeric literal: it is not evaluated, so accept it.
            return None

        msg = None
        if value == 0:
            msg = '%s with a timeout %sarg of 0' % (
                desc, 'kw' if is_kwarg else '')
        elif not self.allow_hardcoded:
            # Any numeric literal counts as hardcoded, not just ints, so
            # ``timeout=0.5`` cannot slip past. %s renders ints exactly as
            # %d did and keeps the fractional part of floats.
            msg = '%s with an hardcoded timeout arg of %s' % (desc, value)

        if msg:
            return [IllegalLine(msg, node, filename)]
        return None
def check(filename, checker=None):
    """Check a file for missing/misconfigured timeouts.

    The file is read, parsed with ``ast.parse`` and walked by a
    ``Visitor``. Exceptions raised while opening, parsing or visiting are
    not handled here and propagate to the caller.

    :param filename: path of the Python source file to inspect.
    :param checker: policy callable given to the visitor; a default
        ``Checker()`` is used when this is falsy.
    :returns: the visitor's list of collected errors.
    """
    visitor = Visitor(filename, checker=checker or Checker())
    with open(filename, 'r') as source_file:
        tree = ast.parse(source_file.read(), filename)
        visitor.visit(tree)
    return visitor.errors
def main(argv=None):
    """Command-line entry point: check each file and print its errors.

    Errors are printed to stdout one per line, followed by a total count
    when there are any.

    :param argv: optional list of command-line arguments; when None,
        argparse reads them from ``sys.argv``.
    :returns: 1 if any file produced an error, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description='Look for python source files missing timeouts',
        epilog=('Exit status is 0 if all files are okay, 1 if any files '
                'have an error. Errors are printed to stdout')
    )
    parser.add_argument(
        '--version',
        action='version',
        version='%(prog)s ' + __version__
    )
    parser.add_argument(
        '--no-hardcoded',
        action='store_true',
        help="Do not allow hardcoded constant"
    )
    parser.add_argument('files', nargs='+', help='Files to check')
    args = parser.parse_args(argv)

    checker = Checker(allow_hardcoded=not args.no_hardcoded)
    errors = []
    for fname in args.files:
        these_errors = check(fname, checker=checker)
        if these_errors:
            # Single-argument print(...) behaves the same as a statement
            # on Python 2 and as a function on Python 3.
            print('\n'.join(str(e) for e in these_errors))
            errors.extend(these_errors)

    if errors:
        print('%d total errors' % len(errors))
        return 1
    return 0
if __name__ == '__main__':
    # main() returns the exit status (0 = all files okay, 1 = errors
    # found). Raising SystemExit passes it to the interpreter; a bare
    # main() call would drop it and always exit 0.
    raise SystemExit(main())