thrift/lib/py/util/TValidator.py (114 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-unsafe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
from thrift.Thrift import TType
_log = logging.getLogger("thrift.validator")
import sys
if sys.version_info[0] >= 3:
basestring = str
unicode = str
long = int
class TValidator:
    """Recursively validate a Thrift object against its ``thrift_spec``.

    Every check method returns ``True`` when the inspected value is valid and
    ``False`` otherwise.  Problems are reported through the
    ``thrift.validator`` logger; a value of the wrong shape (for example a
    string where a list is expected) is reported as invalid rather than
    raising.

    Field specs are read positionally: ``spec[1]`` is the field's thrift
    type, ``spec[2]`` the attribute name and ``spec[3]`` the type-specific
    nested specs.
    """

    # ttype: (type_name, python_type, min_value, max_value)
    # A bound of None means "not range-checked".
    # NOTE: ``bool`` is a subclass of ``int`` in Python, so True/False satisfy
    # the isinstance() test of the integer rows.
    tinfo = {
        TType.BOOL: ("BOOL", bool, None, None),
        TType.BYTE: ("BYTE", int, -128, 127),
        TType.DOUBLE: ("DOUBLE", float, None, None),
        TType.I16: ("I16", int, -32768, 32767),
        TType.I32: ("I32", int, -2147483648, 2147483647),
        # Python integers are unbounded, so I64 needs explicit signed 64-bit
        # limits just like the narrower integer types.
        TType.I64: (
            "I64",
            (int, long),
            -9223372036854775808,
            9223372036854775807,
        ),
        TType.STRING: ("STRING", basestring, None, None),
        TType.UTF8: ("UTF8", unicode, None, None),
    }

    def __init__(self):
        """Create a validator with no custom class validators registered."""
        # Maps a struct class name to a callable taking the struct instance.
        self.custom_validators = {}

    def addClassValidator(self, name, validator):
        """Register ``validator`` for structs whose class name is ``name``.

        ``validator`` is called with the struct instance once all of the
        struct's fields have validated successfully and must return a truthy
        value when the struct is acceptable.  Registering a name again
        replaces the previously registered validator.
        """
        self.custom_validators[name] = validator

    def validate(self, msg):
        """Validate ``msg`` and everything nested inside it.

        Returns ``False`` (after logging an error) when ``msg`` has no
        ``thrift_spec`` attribute; otherwise returns the result of checking
        it as a struct, using its class name as the root of the field paths
        that appear in log messages.
        """
        if not hasattr(msg, "thrift_spec"):
            _log.error("Not a valid thrift object")
            return False
        name = msg.__class__.__name__
        return self.check_struct(name, msg, msg.thrift_spec)

    def check_basic(self, name, value, thrift_type):
        """Check ``value`` against the ``tinfo`` entry for ``thrift_type``.

        The value must be an instance of the expected Python type and, when
        the entry defines bounds, lie within them (inclusive).  Types missing
        from ``tinfo`` are not validated: a warning is logged and ``True`` is
        returned.
        """
        if thrift_type not in self.tinfo:
            # Logger.warn is a deprecated alias; use Logger.warning.
            _log.warning(
                "%s Unrecognized thrift type %d. No validation done!", name, thrift_type
            )
            return True

        t_name, python_type, v_min, v_max = self.tinfo[thrift_type]
        if not isinstance(value, python_type):
            error = "Value %s is not a %s" % (str(value), t_name)
        elif (v_min is not None and value < v_min) or (
            v_max is not None and value > v_max
        ):
            error = "Value %s not within %s boundaries" % (str(value), t_name)
        else:
            error = None

        if error is None:
            _log.debug("%s -> %s (type: %s) OK", name, str(value), t_name)
        else:
            _log.error("ERROR: %s is WRONG. %s", name, error)
        return error is None

    def check_map(self, name, value, k_type, k_specs, v_type, v_specs):
        """Check every key and every value of the map ``value``.

        All entries are visited even after a failure so that each problem is
        logged.  A ``value`` without an ``items`` method is reported as
        invalid.
        """
        _log.debug("%s - MAP check:", name)
        if not hasattr(value, "items"):
            # Without this guard value.items() would raise AttributeError.
            _log.error("ERROR: %s is WRONG. Value %s is not a MAP", name, str(value))
            return False

        ok = True
        for k, v in value.items():
            if not self.check_type("%s key" % (name), k, k_type, k_specs):
                ok = False
            if not self.check_type("%s[%s]" % (name, k), v, v_type, v_specs):
                ok = False
        return ok

    def check_listset(self, name, value, v_type, v_specs):
        """Check every element of the list or set ``value``.

        All elements are visited even after a failure so that each problem is
        logged.  Strings, bytes and non-iterable values are reported as
        invalid.
        """
        _log.debug("%s - LIST/SET check:", name)
        # Strings are iterable, so without this test a string stored in a
        # list<string> field would validate character by character and pass.
        if isinstance(value, (basestring, bytes)):
            iterable = False
        else:
            try:
                iter(value)
            except TypeError:
                iterable = False
            else:
                iterable = True
        if not iterable:
            _log.error(
                "ERROR: %s is WRONG. Value %s is not a LIST/SET", name, str(value)
            )
            return False

        ok = True
        for i, v in enumerate(value):
            if not self.check_type("%s[%d]" % (name, i), v, v_type, v_specs):
                ok = False
        return ok

    def check_type(self, name, value, v_type, specs):
        """Dispatch ``value`` to the checker matching thrift type ``v_type``.

        ``None`` means the field is not set and is always accepted.  The
        layout of ``specs`` depends on ``v_type``:

        * STRUCT: ``specs[1]`` holds the nested struct's specs.
        * MAP: ``(key_type, key_specs, value_type, value_specs)``.
        * LIST / SET: ``(element_type, element_specs)``.
        * anything else: ``specs`` is ignored and ``check_basic`` is used.
        """
        if value is None:
            _log.debug("%s - NOT set", name)
            return True

        if v_type == TType.STRUCT:
            return self.check_struct(name, value, specs[1])
        if v_type == TType.MAP:
            k_type, k_specs, item_type, item_specs = (
                specs[0],
                specs[1],
                specs[2],
                specs[3],
            )
            return self.check_map(name, value, k_type, k_specs, item_type, item_specs)
        if v_type in (TType.LIST, TType.SET):
            return self.check_listset(name, value, specs[0], specs[1])
        return self.check_basic(name, value, v_type)

    def check_struct(self, name, value, specs):
        """Check every field of the struct ``value`` described by ``specs``.

        Returns ``False`` when ``specs`` is ``None``, when ``value`` lacks an
        attribute named by a field spec, or when any field fails its check.
        All fields are visited even after a failure.  If every field is valid
        and a custom validator is registered under the class name of
        ``value``, that validator has the final say.
        """
        _log.debug("%s - STRUCT check:", name)
        if specs is None:
            _log.error("%s - Empty thrift specs, can not be validated", name)
            return False

        # Sentinel distinguishing "attribute absent" from "attribute is None".
        missing = object()
        ok = True
        for spec in specs:
            if spec is None:
                # Some fields of the struct might be not defined or
                # skipped old fields and their spec will be None
                continue
            f_name = name + "." + str(spec[2])
            f_value = getattr(value, spec[2], missing)
            if f_value is missing:
                # Typically means a value of the wrong type was stored where
                # this struct was expected; report it instead of raising
                # AttributeError out of validate().
                _log.error("ERROR: %s is WRONG. Field is missing", f_name)
                ok = False
                continue
            if not self.check_type(f_name, f_value, spec[1], spec[3]):
                ok = False

        class_name = value.__class__.__name__
        if ok and class_name in self.custom_validators:
            if self.custom_validators[class_name](value):
                _log.debug("%s - Custom validator for class %s OK", name, class_name)
            else:
                _log.error(
                    "%s - Custom validator for class %s failed!", name, class_name
                )
                ok = False
        return ok