in easy_rec/python/utils/config_util.py [0:0]
def edit_config(pipeline_config, edit_config_json):
  """Update params specified by automl.

  Walks attribute paths such as ``train_config.optimizer[0].lr`` inside the
  protobuf config and assigns new values parsed from strings.

  Supported selectors inside a path token:
    * ``name[:]``      -- all elements of a repeated field
    * ``name[a:b]``    -- slice of a repeated field (either bound optional)
    * ``name[i]``      -- single element by integer index
    * ``name[k op v]`` -- elements whose sub-attribute ``k`` satisfies
                          ``op`` in {>=, <=, >, <, =} against value ``v``

  Args:
    pipeline_config: EasyRecConfig proto message, edited in place.
    edit_config_json: dict mapping ';'-separated path strings to
      ';'-separated value strings; keys and values are zipped pairwise.

  Returns:
    The same (mutated) pipeline_config.
  """

  def _type_convert(proto, val, parent=None):
    # Cast string `val` to the python type of the existing field value
    # `proto`. Booleans only accept the four literal spellings below.
    if type(val) != type(proto):
      try:
        if isinstance(proto, bool):
          assert val in ['True', 'true', 'False', 'false']
          val = val in ['True', 'true']
        else:
          val = type(proto)(val)
      except ValueError as ex:
        # Fallback for enum-like fields: `val` did not parse as an int, so
        # resolve it as a named attribute on the parent message
        # (presumably an enum value name -> int; verify against callers).
        if parent is None:
          raise ex
        assert isinstance(proto, int)
        val = getattr(parent, val)
        assert isinstance(val, int)
    return val

  def _get_attr(obj, attr, only_last=False):
    # Resolve dotted path `attr` starting at `obj`, fanning out across
    # repeated-field selectors. Returns a list of 4-tuples which the caller
    # unpacks as (value, parent, field_name, index); field_name is None and
    # index >= 0 when the value lives in a repeated container, otherwise
    # field_name is set and index is -1.
    # only_last means we only return the last element in paths array
    attr_toks = [x.strip() for x in attr.split('.') if x != '']
    paths = []
    objs = [obj]
    nobjs = []
    for key in attr_toks:
      # reset paths so they reflect only the current (deepest) token;
      # intermediate tokens' paths are discarded
      paths = []
      for obj in objs:
        if '[' in key:
          # split "name[cond]" into the field name and the bracket body
          pos = key.find('[')
          name, cond = key[:pos], key[pos + 1:]
          cond = cond[:-1]
          update_objs = getattr(obj, name)
          # select all update_objs
          if cond == ':':
            # NOTE(review): tuple order here is (parent, element, ...),
            # unlike the (element, parent, ...) order used by the index/
            # condition branches below. If this token is the FINAL one,
            # the consumer loop would treat the parent message as the
            # value to overwrite -- confirm whether '[:]' is only ever
            # used on intermediate tokens.
            for tid, update_obj in enumerate(update_objs):
              paths.append((obj, update_obj, None, tid))
              nobjs.append(update_obj)
            continue
          # select by range update_objs[1:10]
          if ':' in cond:
            colon_pos = cond.find(':')
            sid = cond[:colon_pos]
            if len(sid) == 0:
              sid = 0  # empty start bound -> from beginning
            else:
              sid = int(sid)
            eid = cond[(colon_pos + 1):]
            if len(eid) == 0:
              eid = len(update_objs)  # empty end bound -> to end
            else:
              eid = int(eid)
            # NOTE(review): same (parent, element, ...) ordering caveat
            # as the '[:]' branch above.
            for tid, update_obj in enumerate(update_objs[sid:eid]):
              paths.append((obj, update_obj, None, tid + sid))
              nobjs.append(update_obj)
            continue
          # for simple index update_objs[0]
          try:
            obj_id = int(cond)
            obj = update_objs[obj_id]
            paths.append((obj, update_objs, None, obj_id))
            nobjs.append(obj)
            continue
          except ValueError:
            # cond is not a plain integer; fall through to the
            # comparison-condition parser below
            pass
          # for complex conditions a[optimizer.lr=20]
          op_func_map = {
              '>=': lambda x, y: x >= y,
              '<=': lambda x, y: x <= y,
              '<': lambda x, y: x < y,
              '>': lambda x, y: x > y,
              '=': lambda x, y: x == y
          }
          cond_key = None
          cond_val = None
          op_func = None
          # two-char operators are tried first so '>=' is not
          # mis-parsed as '>' followed by '=...'
          for op in ['>=', '<=', '>', '<', '=']:
            tmp_pos = cond.rfind(op)
            if tmp_pos != -1:
              cond_key = cond[:tmp_pos]
              cond_val = cond[(tmp_pos + len(op)):]
              op_func = op_func_map[op]
              break
          assert cond_key is not None, 'invalid cond: %s' % cond
          assert cond_val is not None, 'invalid cond: %s' % cond
          # keep every element whose sub-attribute satisfies the condition
          for tid, update_obj in enumerate(update_objs):
            tmp, tmp_parent, _, _ = _get_attr(
                update_obj, cond_key, only_last=True)
            # NOTE(review): cond_val is rebound to the converted value on
            # the first iteration; subsequent iterations pass the already
            # converted value back in (harmless since _type_convert is a
            # no-op when types match, but worth confirming).
            cond_val = _type_convert(tmp, cond_val, tmp_parent)
            if op_func(tmp, cond_val):
              obj_id = tid
              paths.append((update_obj, update_objs, None, obj_id))
              nobjs.append(update_obj)
        else:
          # plain attribute access: (value, parent, field_name, -1)
          sub_obj = getattr(obj, key)
          paths.append((sub_obj, obj, key, -1))
          nobjs.append(sub_obj)
      # exchange to prepare for parsing next token
      objs = nobjs
      nobjs = []
    if only_last:
      return paths[-1]
    else:
      return paths

  for param_keys in edit_config_json:
    # multiple keys/vals combination: both sides are ';'-separated and
    # zipped pairwise, so extra keys or vals beyond the shorter list
    # are silently ignored
    param_vals = edit_config_json[param_keys]
    param_vals = [x.strip() for x in str(param_vals).split(';')]
    param_keys = [x.strip() for x in str(param_keys).split(';')]
    for param_key, param_val in zip(param_keys, param_vals):
      update_obj = pipeline_config
      tmp_paths = _get_attr(update_obj, param_key)
      # update a set of objs
      for tmp_val, tmp_obj, tmp_name, tmp_id in tmp_paths:
        # list and dict are not basic types, must be handle separately
        basic_types = _get_basic_types()
        if type(tmp_val) in basic_types:
          # simple type cast, then write back either by repeated-field
          # index (tmp_name is None) or by attribute name
          tmp_val = _type_convert(tmp_val, param_val, tmp_obj)
          if tmp_name is None:
            tmp_obj[tmp_id] = tmp_val
          else:
            setattr(tmp_obj, tmp_name, tmp_val)
        elif 'Scalar' in str(type(tmp_val)) and 'ClearField' in dir(tmp_obj):
          # protobuf scalar wrapper (duck-typed via class-name check):
          # clear the field on the parent message and re-parse the new
          # value in text format
          tmp_obj.ClearField(tmp_name)
          text_format.Parse('%s:%s' % (tmp_name, param_val), tmp_obj)
        else:
          # sub-message: wipe it and parse the value as text-format proto;
          # an optional outer {...} wrapper is stripped first
          tmp_val.Clear()
          param_val = param_val.strip()
          if param_val.startswith('{') and param_val.endswith('}'):
            param_val = param_val[1:-1]
          text_format.Parse(param_val, tmp_val)
  return pipeline_config