src/ab/plugins/calllimit/core.py (62 lines of code) (raw):

from flask import request, Response from ab import app from ab.utils import logger from ab.utils.exceptions import AlgorithmException from ab.plugins.cache.redis import cache_plugin # 不限制call次数 CALL_UNTRALIMIT = -1 def get_limit_key(key): return "limit:{}".format(key) def legal(ak, sk): """ 验证ak和sk是否合法 :param ak: :param sk: :return: """ if ak is None or sk is None: return False cache_client = cache_plugin.get_cache_client() if str(cache_client.get(ak).strip(), 'UTF-8') == sk.strip(): return True return False def get_call_count(key): """ 获得调用次数 :param key: :return: """ cache_client = cache_plugin.get_cache_client() ret = cache_client.iget(key) return 0 if not ret else ret def get_call_limit(key): """ 获得调用限制次数 :param key: :return: -1 不允许调用 """ cache_client = cache_plugin.get_cache_client() ret = cache_client.iget(get_limit_key(key)) return CALL_UNTRALIMIT if not ret else ret def inc_call_count(key): """ 增加调用次数 :param key: :return: """ current_count = get_call_count(key) current_count = current_count + 1 cache_client = cache_plugin.get_cache_client() cache_client.set(key, current_count) @app.before_request def before_call(): path = request.path request.limit = False if request.path.startswith("/api/algorithm"): ak = request.args.get("ak") if legal(ak, request.args.get("sk")): request.ak = ak key = "{}:{}".format(ak, path) current_count = get_call_count(key) call_limit = get_call_limit(key) if call_limit != CALL_UNTRALIMIT: request.limit = True if not current_count < get_call_limit(key): raise AlgorithmException(data="the API exceed the call limit : {}".format(key)) else: raise AlgorithmException(data="wrong ak or sk") @app.after_request def post_call(response): path = request.path if request.limit: key = "{}:{}".format(request.ak, path) inc_call_count(key) request.ak = None return response class CallLimit: def __init__(self): self.platform = None def set_platform(self, platform): self.platform = platform def start(self, config): logger.info("[plugin] CallLimit start") def stop(self): logger.info("[plugin] CallLimit stop")