cacheck/ccadb/db.py (188 lines of code) (raw):

import psycopg2 import json import copy import itertools from collections.abc import Iterable from flask import current_app as app class CCADB(): def __init__(self): self.conn = self.connect( app.config['DB_HOST'], app.config['DB_PORT'], app.config['DB_USER'], app.config['DB_PASS'], app.config['DB_DATABASE'] ) def connect(self, host, port, user, passw, database): """ Connects CCADB instance to a Postgres CCADB :param host: String: The hostname of the db server :param port: Int: port number to connect on :param user: String: The username to connect with :param passw: String: The password to use :param database: String: The database name to use :return: None :rtype: None """ conn = psycopg2.connect("dbname={} host={} port={} user={} password={}".format( database, host, port, user, passw)) conn.set_session(readonly=True, autocommit=True) return conn def _query_db(self, query, params): """ Internal helper function, wraps all db queries Returns db cursor :param query: SQL query string :param params: iterable list of params for the query :return: A db cursor with the result :rtype: psycopg2.cursor """ if not isinstance(params, Iterable): raise RuntimeError("SQL query parameters needs to be iterable - {} passed".format(type(params))) cursor = self.conn.cursor() if app.config['DEBUG']: print(query) print(params) cursor.execute(query, params) return cursor def issuer_ca_id_from_digest(self, digest, fingerprint): """ Finds certificates based on a digest type and returns the issuing CA ID. :param fingerprint: Base64 encoded string that corresponds to the fingerprint of a certificate :return: The issuing CA ID of the certificate :rtype: int """ cursor = self._query_db("SELECT issuer_ca_id FROM certificate WHERE digest(certificate, %s) = decode(%s, 'hex')", (digest, fingerprint,)) r = cursor.fetchone() if r: issuer_ca_id = r[0] return str(r[0]), 200 return "-1", 400 def ca_id_from_digest(self, digest, fingerprint): """ Finds certificates based on a digest type and returns the CA ID. :param fingerprint: Base64 encoded string that corresponds to the fingerprint of a certificate :return: The issuing CA ID of the certificate :rtype: int """ cursor = self._query_db("""SELECT ca_certificate.ca_id FROM certificate LEFT JOIN ca_certificate ON certificate.id=ca_certificate.certificate_id WHERE digest(certificate, %s) = decode(%s, 'hex')""", (digest, fingerprint,)) r = cursor.fetchone() if r: issuer_ca_id = r[0] return str(r[0]), 200 return "-1", 400 def cert_info(self, certificate_id): """ Finds certificates based on a digest type and returns the issuing CA ID. :param certificate_id: (int) ID of the certificate to lookup :return: information about the certificate :rtype: dict """ if not isinstance(certificate_id, int): raise RuntimeError("Error! certificate_id needs to be an int, not {}".format(type(certificate_id))) ##WARNING: If any x509* functions fail, rows are not returned #x509_extensions(certificate), #x509_extkeyusages(certificate), #x509_getpathlenconstraint(certificate), #x509_altnames(certificate), keys = [ 'notbefore', 'notafter', 'subjectname', 'commonname', 'serialnumber', 'name', 'authoritykeyid', 'publickey', 'subjectkeyidentifier', 'issuername' ] # 'crldistributionpoints', 'authorityinfoaccess', # 'canissuecerts', 'certpolicies', # 'keyalgorithm', 'keysize' #] cert_funcs = map(lambda x: x[0] + x[1] + "(certificate)", zip(itertools.repeat("x509_"), keys)) cursor = self._query_db(""" SELECT issuer_ca_id, digest(certificate, 'sha256'), digest(certificate, 'sha1'), {} FROM certificate WHERE id=%s """.format(", \n\t\t".join(cert_funcs)), (certificate_id, )) r = cursor.fetchone() if r: a = {'issuer_ca_id': r[0], 'id': certificate_id, 'sha256_fingerprint': r[1], 'sha1_fingerprint': r[2] } a.update(dict(zip(keys, r[3:]))) for k, v in a.items(): if isinstance(v, memoryview): a[k] = v.hex() return a return "-1" def ca_id_from_cert_id(self, cert_id): """ Finds the issuing CA id for a corresponding certificate id :param cert_id: int certificate ID :return: The issuing CA ID :rtype: int """ cursor = self._query_db("""SELECT ca_id from certificate LEFT JOIN ca_certificate ON ca_certificate.certificate_id=certificate.id WHERE certificate.id=%s""", (cert_id,)) caid = cursor.fetchone()[0] cursor.close() return caid def cert_id_from_ca_id(self, ca_id): """ Finds the certificate id for the CA :param ca_id: int CA ID :return: The certificate id :rtype: int """ cursor = self._query_db('SELECT certificate_id from ca_certificate WHERE ca_id=%s', (ca_id,)) cert_id = cursor.fetchone()[0] cursor.close() return cert_id @staticmethod def _rec_get_keys(d): """ recursively get set of all keys in dictionary :param d: dict :return: set of all keys :rtype: set """ keys = set(map(lambda x: int(x), d.keys())) for v in d.values(): keys = keys.union( CCADB._rec_get_keys(v) ) return keys def build_ca_tree(self, parent_ca_id, depth): """ Build a tree of intermediate CAs :param parent_ca_id: int(CA ID) of root :param depth: maximum tree depth :return: A tree of CA IDs :rtype: dict """ skip_ca_ids = set([]) return self._rec_get_ca_children(parent_ca_id, skip_ca_ids, depth) def _rec_get_ca_children(self, ca_id, parent_ca_ids, depth): """ Recursively get CA children and build tree of dicts NB: This does not get any cert ids, only CAs :param ca_id: int(CA_ID) :param parent_ca_ids: A set of previously seen CA IDs. Ignore a child CA if it is already contained in the tree. """ ca_tree = {} if depth == 0: return ca_tree, {} elif depth == -1: pass elif depth > 0: depth = depth - 1 else: raise RuntimeError("Logic error in recursive CA tree builder. depth is < -1") ##get children ccas, ca_cn_map = self.get_child_ca_ids(ca_id) for cca_id in ccas: #skip child ca ids already included if cca_id in parent_ca_ids: continue if not isinstance(cca_id, type(None)): parent_ca_ids.add(cca_id) cca_tree, cca_cn_map = self._rec_get_ca_children(cca_id, parent_ca_ids, depth) ca_tree[cca_id] = cca_tree ca_cn_map.update(cca_cn_map) return ca_tree, ca_cn_map def pprint_ca_id(self, ca_id): """ Finds the issuing CA id for a corresponding certificate id :param cert_id: int certificate ID :return: The issuing CA ID :rtype: int """ cursor = self._query_db(""" SELECT x509_print(certificate.certificate) FROM certificate LEFT JOIN ca_certificate ON ca_certificate.certificate_id=certificate.id WHERE ca_id=%s """, (ca_id,)) return cursor.fetchone() def get_child_ca_ids(self, ca_id): """ Finds child CAs from the parent CA ID :param ca_id: int CA ID :return: A list of child CA IDs :rtype: set([int]) """ cca_ids = set() ca_cn_map = {} cursor = self._query_db( """ SELECT ca_id, x509_commonname(certificate.certificate) FROM certificate LEFT JOIN ca_certificate ON certificate_id=id WHERE issuer_ca_id=%s AND x509_canissuecerts(certificate.certificate)=True; """ , (ca_id,)) res = cursor.fetchall() for cca_id in res: if not isinstance(cca_id[0] , type(None)): cca_ids.add(cca_id[0]) ca_cn_map[cca_id[0]] = cca_id[1] #remove parent ca_id if int(ca_id) in cca_ids: cca_ids.remove(int(ca_id)) cursor.close() return cca_ids, ca_cn_map def lint_issues_for_ca_ids(self, ca_ids, daterange, cert_options, linters): """ Finds lint issues for a CA id :param fingerprint: Base64 encoded string that corresponds to the fingerprint of a certificate :return: The issuing CA ID of the certificate :rtype: int """ if True not in linters: raise RuntimeError("Error! No linters selected!") exclude_onecrl, exclude_expired_certs, exclude_revoked, exclude_technically_constrained = cert_options lint_names = [ "cablint", "zlint", "x509lint" ] query_linter = set(filter(lambda x: x[1], zip(lint_names, linters))) query_conf = '{' + ','.join(list(zip(*query_linter))[0]) + '}' caids_conf = '{' + ','.join(ca_ids) + '}' args = [caids_conf, query_conf] if app.config['DEBUG']: print("ca ids:", ca_ids, ", start: ", daterange[0], ", end:", daterange[1], ", linters: ", linters) print(query_linter) print(query_conf) sql_query = "SELECT lint_cert_issue.certificate_id, lint_issue.issue_text, lint_issue_id, certificate.issuer_ca_id, x509_notbefore(certificate.certificate), x509_notafter(certificate.certificate), issue_text, linter, severity, " sql_query += "x509_issuername(certificate.certificate), x509_subjectname(certificate.certificate), " sql_query += "encode(digest(certificate.certificate, 'sha256'), 'hex'), " sql_query += "EXISTS(SELECT certificate_id FROM google_revoked WHERE certificate_id=certificate.id), " sql_query += "EXISTS(SELECT certificate_id FROM mozilla_onecrl WHERE certificate_id=certificate.id), " sql_query += "EXISTS(SELECT certificate_id FROM microsoft_disallowedcert WHERE certificate_id=certificate.id) " sql_query += "FROM lint_cert_issue LEFT JOIN lint_issue " sql_query += "ON lint_cert_issue.lint_issue_id=lint_issue.id " ##left join on certificate to filter expired certificates sql_query += "LEFT JOIN certificate " sql_query += "ON lint_cert_issue.certificate_id=certificate.id " sql_query += "WHERE lint_cert_issue.issuer_ca_id = ANY(%s) " #sql_query += "AND linter = ANY('{cablint,x509lint,zlint}')" sql_query += "AND linter = ANY(%s) " if exclude_expired_certs: sql_query += "AND x509_notafter(certificate.certificate) > NOW() " if exclude_technically_constrained: #TODO: is_technically_constrained2? sql_query += "AND is_technically_constrained(certificate.certificate) = false " if daterange[0]: sql_query += "AND lint_cert_issue.not_before_date > %s " args.append(daterange[0].strftime('%Y-%m-%d')) if daterange[1]: sql_query += "AND lint_cert_issue.not_before_date < %s " args.append(daterange[1].strftime('%Y-%m-%d')) if exclude_onecrl: sql_query += "AND lint_cert_issue.certificate_id NOT IN ( " sql_query += "SELECT certificate_id FROM mozilla_onecrl WHERE certificate_id IS NOT NULL ) " cursor = self._query_db(sql_query, args) res = cursor.fetchall() lint_issues = [] for r in res: print(r) fields = [ 'certificate_id', 'issue_text', 'lint_issue_id', 'issuer_ca_id', 'not_before_date', 'not_after_date', 'issue_text', 'linter', 'severity', 'issuer_cn', 'subject_cn', 'sha256_fingerprint', 'google_revoked', 'onecrl_revoked', 'microsoft_revoked' ] lint_issue = dict(zip(fields, r)) #if not isinstance(lint_issue['revocation_status'], str): # lint_issue['revocation_status'] = 'Not Revoked' if exclude_revoked: revoked = False for k in [ 'onecrl_revoked', 'microsoft_revoked', 'google_revoked' ]: if lint_issue[k]: revoked = True break if not revoked: lint_issues.append(lint_issue) else: lint_issues.append(lint_issue) cursor.close() return lint_issues