def edit_config()

in easy_rec/python/utils/config_util.py [0:0]


def edit_config(pipeline_config, edit_config_json):
  """Update params specified by automl.

  Every key of ``edit_config_json`` is a '.'-separated attribute path into
  ``pipeline_config`` and the matching value is the new value (its ``str()``
  is used). Several updates can be packed into one entry by separating both
  the key and the value with ';'; they are paired up positionally (keys or
  values without a partner are ignored).

  A path token may select elements of a repeated field:
    * ``name[:]``       -- all elements
    * ``name[1:3]``     -- a slice; either bound may be omitted and negative
      bounds count from the end, as in Python slicing
    * ``name[0]``       -- one element by index
    * ``name[a.b>=3]``  -- elements whose sub-path ``a.b`` satisfies the
      comparison; supported operators are >=, <=, >, < and =
  '.' characters inside [...] belong to the selector, not to the outer path.

  How each selected target is updated:
    * values whose type is in ``_get_basic_types()`` are converted from the
      string and assigned. Bools accept 'True'/'true'/'False'/'false'. For an
      int, a string that is not a numeric literal is looked up by name on the
      parent object (e.g. an enum value name);
    * values whose type name contains 'Scalar' (presumably repeated scalar
      fields) are cleared on the parent, which is then filled by parsing
      ``'<name>:<value>'`` in protobuf text format;
    * any other value is cleared and re-parsed from the value in protobuf
      text format, after stripping one enclosing '{...}'.

  Args:
    pipeline_config: EasyRecConfig
    edit_config_json: edit config json, a mapping from attribute path to
      the new value.

  Returns:
    The same ``pipeline_config`` object, modified in place.
  """

  def _type_convert(proto, val, parent=None):
    # Convert `val` to the Python type of the existing value `proto`.
    if type(val) != type(proto):
      try:
        if isinstance(proto, bool):
          # bool('False') is True, so bools are matched on their spelling.
          assert val in ['True', 'true', 'False', 'false']
          val = val in ['True', 'true']
        else:
          val = type(proto)(val)
      except ValueError:
        if parent is None:
          raise
        # Not a numeric literal: resolve it as a named int constant
        # (e.g. an enum value name) on the parent object.
        assert isinstance(proto, int)
        val = getattr(parent, val)
        assert isinstance(val, int)
    return val

  def _split_attr_path(attr):
    # Split `attr` on '.' only outside of [...], so that selectors such as
    # a[optimizer.lr=20] or a[lr=0.5] stay in a single token.
    toks = []
    depth = 0
    start = 0
    for pos, ch in enumerate(attr):
      if ch == '[':
        depth += 1
      elif ch == ']':
        depth -= 1
      elif ch == '.' and depth <= 0:
        toks.append(attr[start:pos])
        start = pos + 1
    toks.append(attr[start:])
    return [x.strip() for x in toks if x != '']

  # Comparison operators usable in selectors. Two-char operators come first
  # so that '>=' / '<=' are not split as '>' / '<' / '='.
  cond_ops = [
      ('>=', lambda x, y: x >= y),
      ('<=', lambda x, y: x <= y),
      ('>', lambda x, y: x > y),
      ('<', lambda x, y: x < y),
      ('=', lambda x, y: x == y),
  ]

  def _get_attr(obj, attr, only_last=False):
    # Resolve `attr` starting from `obj`. For the LAST path token, return a
    # list of (value, container, field_name, index) tuples:
    #   * plain field access: value is getattr(container, field_name) and
    #     index is -1;
    #   * repeated-field selection: field_name is None and value is
    #     container[index].
    # only_last means we only return the last element in the paths array.
    attr_toks = _split_attr_path(attr)
    paths = []
    objs = [obj]
    nobjs = []
    for key in attr_toks:
      # only the paths of the final token are returned, so start afresh
      paths = []
      for cur_obj in objs:
        if '[' not in key:
          sub_obj = getattr(cur_obj, key)
          paths.append((sub_obj, cur_obj, key, -1))
          nobjs.append(sub_obj)
          continue

        pos = key.find('[')
        # cond is the selector text without the trailing ']'
        name, cond = key[:pos], key[pos + 1:-1]
        update_objs = getattr(cur_obj, name)

        # select by range: update_objs[:], update_objs[1:10], ...
        if ':' in cond:
          sid, eid = cond.split(':', 1)
          sid = int(sid) if len(sid) > 0 else 0
          eid = int(eid) if len(eid) > 0 else len(update_objs)
          # Normalize like Python slicing so that every index stored in the
          # path is valid for assigning back into the container.
          start, stop, _ = slice(sid, eid).indices(len(update_objs))
          for tid in range(start, stop):
            paths.append((update_objs[tid], update_objs, None, tid))
            nobjs.append(update_objs[tid])
          continue

        # select by simple index: update_objs[0]
        try:
          obj_id = int(cond)
        except ValueError:
          obj_id = None
        if obj_id is not None:
          sel_obj = update_objs[obj_id]
          paths.append((sel_obj, update_objs, None, obj_id))
          nobjs.append(sel_obj)
          continue

        # select by condition: a[optimizer.lr=20]
        cond_key = None
        cond_val = None
        op_func = None
        for op, func in cond_ops:
          tmp_pos = cond.rfind(op)
          if tmp_pos != -1:
            cond_key = cond[:tmp_pos]
            cond_val = cond[(tmp_pos + len(op)):]
            op_func = func
            break

        assert cond_key is not None, 'invalid cond: %s' % cond
        assert cond_val is not None, 'invalid cond: %s' % cond

        for tid, update_obj in enumerate(update_objs):
          tmp, tmp_parent, _, _ = _get_attr(
              update_obj, cond_key, only_last=True)
          # After the first conversion the types match, so later calls
          # return cond_val unchanged.
          cond_val = _type_convert(tmp, cond_val, tmp_parent)
          if op_func(tmp, cond_val):
            paths.append((update_obj, update_objs, None, tid))
            nobjs.append(update_obj)
      # exchange to prepare for parsing next token
      objs = nobjs
      nobjs = []
    if only_last:
      return paths[-1]
    return paths

  # list and dict are not basic types, must be handled separately
  basic_types = _get_basic_types()
  for raw_keys in edit_config_json:
    # multiple keys/vals combination, separated by ';'
    raw_vals = edit_config_json[raw_keys]
    param_vals = [x.strip() for x in str(raw_vals).split(';')]
    param_keys = [x.strip() for x in str(raw_keys).split(';')]
    for param_key, param_val in zip(param_keys, param_vals):
      tmp_paths = _get_attr(pipeline_config, param_key)
      # update a set of objs
      for tmp_val, tmp_obj, tmp_name, tmp_id in tmp_paths:
        if type(tmp_val) in basic_types:
          # simple type cast
          tmp_val = _type_convert(tmp_val, param_val, tmp_obj)
          if tmp_name is None:
            tmp_obj[tmp_id] = tmp_val
          else:
            setattr(tmp_obj, tmp_name, tmp_val)
        elif 'Scalar' in str(type(tmp_val)) and 'ClearField' in dir(tmp_obj):
          # presumably a repeated scalar field: replace all of its values
          tmp_obj.ClearField(tmp_name)
          text_format.Parse('%s:%s' % (tmp_name, param_val), tmp_obj)
        else:
          # Message value: clear it and parse the text-format value into it.
          # Use a local copy so brace-stripping is not applied again to the
          # value for the next selected path.
          msg_text = param_val.strip()
          if msg_text.startswith('{') and msg_text.endswith('}'):
            msg_text = msg_text[1:-1]
          tmp_val.Clear()
          text_format.Parse(msg_text, tmp_val)

  return pipeline_config