easy_rec/python/tools/convert_config_format.py (36 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Convert an EasyRecConfig file between the .config and .json formats.

The format of each file is chosen from its filename suffix: '.config' is
handled with google.protobuf.text_format and '.json' with
google.protobuf.json_format.
"""
import os

from google.protobuf import json_format
from google.protobuf import text_format

from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig


def load_config(input_config):
  """Load an EasyRecConfig from a .config or .json file.

  Args:
    input_config: path of the file to read; the suffix selects the parser
      ('.config' -> text_format.Merge, '.json' -> json_format.Parse).

  Returns:
    The populated EasyRecConfig message.

  Raises:
    AssertionError: if input_config ends with neither '.config' nor '.json'.
  """
  pipeline_config = EasyRecConfig()
  with open(input_config, 'r') as fin:
    tmp_str = fin.read()

  # The content is fully read above, so parsing does not need the open handle.
  if input_config.endswith('.config'):
    text_format.Merge(tmp_str, pipeline_config)
  elif input_config.endswith('.json'):
    json_format.Parse(tmp_str, pipeline_config)
  else:
    # Raised explicitly instead of `assert False`, which is removed under
    # `python -O` and would let an empty config be returned silently.
    raise AssertionError(
        'only .config/.json are supported(%s)' % input_config)
  return pipeline_config


def save_config(pipeline_config, save_path):
  """Serialize pipeline_config to save_path as .config or .json.

  Args:
    pipeline_config: the protobuf message to serialize.
    save_path: destination path; the suffix selects the serializer
      ('.config' -> text_format.MessageToString with as_utf8=True,
      '.json' -> json_format.MessageToJson with
      preserving_proto_field_name=True).

  Raises:
    AssertionError: if save_path ends with neither '.config' nor '.json'.
  """
  # Validate the suffix and serialize BEFORE opening the file: opening with
  # 'w' truncates, so a bad suffix or a serialization error must not be able
  # to wipe out an existing file or leave an empty one behind.
  if save_path.endswith('.config'):
    content = text_format.MessageToString(pipeline_config, as_utf8=True)
  elif save_path.endswith('.json'):
    content = json_format.MessageToJson(
        pipeline_config, preserving_proto_field_name=True)
  else:
    # Explicit raise: `assert False` would be stripped under `python -O`.
    raise AssertionError('only .config/.json are supported(%s)' % save_path)

  with open(save_path, 'w') as fout:
    fout.write(content)


if __name__ == '__main__':
  import argparse

  parser = argparse.ArgumentParser()
  # Both paths are mandatory; marking them required gives a usage error
  # instead of a TypeError from passing None to os.path.exists()/open().
  parser.add_argument(
      '--input_config', type=str, required=True, help='input_config path')
  parser.add_argument(
      '--output_config', type=str, required=True, help='output_config path')
  args = parser.parse_args()

  # Report a missing input through argparse rather than an assert, which
  # would be skipped under `python -O`.
  if not os.path.exists(args.input_config):
    parser.error('input_config does not exist: %s' % args.input_config)

  pipeline_config = load_config(args.input_config)
  save_config(pipeline_config, args.output_config)