def export_big_model_to_oss()

in easy_rec/python/utils/export_big_model.py [0:0]


# keys of oss_params whose values are credentials and must never be logged
_OSS_SECRET_KEYS = ('oss_ak', 'oss_sk')


def _mask_secret(value):
  """Return a fixed placeholder for a non-empty credential, '' otherwise."""
  return '******' if value else ''


def _build_tensor_infos(graph, tensors):
  """Map each key of tensors to the TensorInfo of the same-named tensor in graph."""
  return {
      key: tf.saved_model.utils.build_tensor_info(
          graph.get_tensor_by_name(tensors[key].name)) for key in tensors
  }


def _incr_update_to_json(incr_update):
  """Build the dict dumped to incr_update.txt; kafka > datahub > fs priority."""
  incr_update_json = {}
  for storage in ('kafka', 'datahub'):
    if storage in incr_update:
      incr_update_json['storage'] = storage
      incr_update_json[storage] = json.loads(
          json_format.MessageToJson(
              incr_update[storage], preserving_proto_field_name=True))
      return incr_update_json
  if 'fs' in incr_update:
    incr_update_json['storage'] = 'fs'
    incr_update_json['fs'] = {'incr_save_dir': incr_update['fs'].mount_path}
  return incr_update_json


def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
                            serving_input_fn, estimator, checkpoint_path,
                            verbose):
  """Export a saved model whose embeddings are written to / read from oss.

  The embedding variables found in the graph are (optionally) dumped to oss
  by the ops in libembed_op.so, the meta graph is rewritten by
  MetaGraphEditor.edit_graph_for_oss, and the edited graph is saved as a
  SavedModel under export_dir/<meta_graph_version>.

  Args:
    export_dir: base directory; the model is saved into a sub directory
      named after the meta graph version.
    pipeline_config: pipeline config; only model_dir is read, to locate the
      latest checkpoint when checkpoint_path is empty.
    oss_params: dict of export options. Keys read here: oss_path,
      oss_endpoint, oss_ak, oss_sk, oss_embedding_version (defaults to the
      current unix timestamp), oss_write_kv (falsy skips the upload),
      oss_threads (default 5), oss_timeout (default 1500) and incr_update.
    serving_input_fn: callable returning an object with features and
      receiver_tensors attributes.
    estimator: estimator providing _export_model_fn and params.
    checkpoint_path: checkpoint to restore; if empty, the latest checkpoint
      in pipeline_config.model_dir is used.
    verbose: forwarded to the oss write ops; when true, export_dir is also
      passed to MetaGraphEditor as debug_dir.

  Return:
    the directory (including the version sub directory) the model is
    saved to.
  """
  for key in oss_params:
    val = oss_params[key]
    if key in _OSS_SECRET_KEYS:
      # never leak the access key / secret into the logs
      val = _mask_secret(val)
    logging.info('%s: %s', key, val)

  write_kv_lib_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
  kv_module = tf.load_op_library(write_kv_lib_path)

  if not checkpoint_path:
    checkpoint_path = estimator_utils.latest_checkpoint(
        pipeline_config.model_dir)
  logging.info('checkpoint_path = %s' % checkpoint_path)

  server = None
  cluster = None
  if 'TF_CONFIG' in os.environ:
    # change chief to master
    tf_config = estimator_utils.chief_to_master()
    if tf_config['task']['type'] == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = tf.train.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      # ps tasks only serve variables: they wait here instead of exporting
      server.join()
    elif tf_config['task']['type'] == 'master':
      if 'ps' in tf_config['cluster']:
        cluster = tf.train.ClusterSpec(tf_config['cluster'])
        server = tf.train.Server(cluster, job_name='master', task_index=0)
        server_target = server.target
        logging.info('server_target = %s' % server_target)

  serving_input = serving_input_fn()
  features = serving_input.features
  inputs = serving_input.receiver_tensors

  if cluster:
    logging.info('cluster = ' + str(cluster))
  with tf.device(
      replica_device_setter(
          worker_device='/job:master/task:0', cluster=cluster)):
    outputs = estimator._export_model_fn(features, None, None,
                                         estimator.params).predictions

  meta_graph_def = export_meta_graph()
  # the version doubles as the name of the export sub directory
  oss_embedding_version = oss_params.get('oss_embedding_version', '')
  meta_graph_def.meta_info_def.meta_graph_version = (
      oss_embedding_version or str(int(time.time())))

  logging.info('meta_graph_version = %s' %
               meta_graph_def.meta_info_def.meta_graph_version)

  # normalized embedding name -> {part_id: variable}
  embed_var_parts = {}
  # variable -> normalized name, replaced by the string id further below
  embed_norm_name = {}
  # variable -> offset of its first row among the sorted parts
  embed_spos = {}
  # pai embedding variable: device -> [(norm_name, keys, values, part_id)]
  embedding_vars = {}
  norm_name_to_ids = {}
  for x in global_variables():
    tf.logging.info('global var: %s %s %s' % (x.name, str(type(x)), x.device))
    if 'EmbeddingVariable' in str(type(x)):
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      tmp_export = x.export()
      embedding_vars.setdefault(x.device, []).append(
          (norm_name, tmp_export.keys, tmp_export.values, part_id))
    elif '/embedding_weights:' in x.name or '/embedding_weights/part_' in x.name:
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      embed_norm_name[x] = norm_name
      embed_var_parts.setdefault(norm_name, {})[part_id] = x

  # replace the placeholder values with ids following the iteration order
  for tid, t in enumerate(norm_name_to_ids.keys()):
    norm_name_to_ids[t] = str(tid)

  for x in embed_norm_name:
    embed_norm_name[x] = norm_name_to_ids[embed_norm_name[x]]

  for norm_name in embed_var_parts:
    parts = embed_var_parts[norm_name]
    spos = 0
    # accumulate the first dimension of the parts in ascending part_id order
    for part_id in sorted(parts):
      embed_spos[parts[part_id]] = spos
      spos += parts[part_id].get_shape()[0]

  oss_path = oss_params.get('oss_path', '')
  oss_endpoint = oss_params.get('oss_endpoint', '')
  oss_ak = oss_params.get('oss_ak', '')
  oss_sk = oss_params.get('oss_sk', '')
  # credentials are masked: log files must not contain the ak / sk
  logging.info('will export to oss: %s %s %s %s', oss_path, oss_endpoint,
               _mask_secret(oss_ak), _mask_secret(oss_sk))

  if oss_params.get('oss_write_kv', ''):
    # group embed by devices
    per_device_vars = {}
    for x in embed_norm_name:
      per_device_vars.setdefault(x.device, []).append(x)

    all_write_res = []
    for tmp_dev in per_device_vars:
      tmp_vars = per_device_vars[tmp_dev]
      # build each write op on the device that holds the variables
      with tf.device(tmp_dev):
        tmp_names = [embed_norm_name[v] for v in tmp_vars]
        tmp_spos = [np.array(embed_spos[v], dtype=np.int64) for v in tmp_vars]
        write_kv_res = kv_module.oss_write_kv(
            tmp_names,
            tmp_vars,
            tmp_spos,
            osspath=oss_path,
            endpoint=oss_endpoint,
            ak=oss_ak,
            sk=oss_sk,
            threads=oss_params.get('oss_threads', 5),
            timeout=5,
            expire=5,
            verbose=verbose)
        all_write_res.append(write_kv_res)

    for tmp_dev in embedding_vars:
      with tf.device(tmp_dev):
        tmp_vs = embedding_vars[tmp_dev]
        tmp_sparse_names = [norm_name_to_ids[x[0]] for x in tmp_vs]
        tmp_sparse_keys = [x[1] for x in tmp_vs]
        tmp_sparse_vals = [x[2] for x in tmp_vs]
        tmp_part_ids = [x[3] for x in tmp_vs]
        write_sparse_kv_res = kv_module.oss_write_sparse_kv(
            tmp_sparse_names,
            tmp_sparse_vals,
            tmp_sparse_keys,
            tmp_part_ids,
            osspath=oss_path,
            endpoint=oss_endpoint,
            ak=oss_ak,
            sk=oss_sk,
            version=meta_graph_def.meta_info_def.meta_graph_version,
            threads=oss_params.get('oss_threads', 5),
            verbose=verbose)
        all_write_res.append(write_sparse_kv_res)

    session_config = ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    chief_sess_creator = ChiefSessionCreator(
        master=server.target if server else '',
        checkpoint_filename_with_path=checkpoint_path,
        config=session_config)
    with tf.train.MonitoredSession(
        session_creator=chief_sess_creator,
        hooks=None,
        stop_grace_period_secs=120) as sess:
      dump_flags = sess.run(all_write_res)
      logging.info('write embedding to oss succeed: %s' % str(dump_flags))
  else:
    logging.info('will skip write embedding to oss because '
                 'oss_write_kv is set to 0.')

  # delete embedding_weights collections so that it could be re imported
  tmp_drop = []
  for k in meta_graph_def.collection_def:
    v = meta_graph_def.collection_def[k]
    if len(
        v.node_list.value) > 0 and 'embedding_weights' in v.node_list.value[0]:
      tmp_drop.append(k)
  for k in tmp_drop:
    meta_graph_def.collection_def.pop(k)

  meta_graph_editor = MetaGraphEditor(
      write_kv_lib_path,
      None,
      oss_path=oss_path,
      oss_endpoint=oss_endpoint,
      oss_ak=oss_ak,
      oss_sk=oss_sk,
      oss_timeout=oss_params.get('oss_timeout', 1500),
      meta_graph_def=meta_graph_def,
      norm_name_to_ids=norm_name_to_ids,
      incr_update_params=oss_params.get('incr_update', None),
      debug_dir=export_dir if verbose else '')
  meta_graph_editor.edit_graph_for_oss()
  tf.reset_default_graph()

  # load the edited graph into the fresh default graph; the returned saver is
  # not needed because a new Saver is created before restoring
  tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)
  graph = tf.get_default_graph()

  embed_name_to_id_file = os.path.join(export_dir, 'embed_name_to_ids.txt')
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name in norm_name_to_ids:
      fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
  ops.add_to_collection(
      ops.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))

  if 'incr_update' in oss_params:
    dense_train_vars_path = os.path.join(
        os.path.dirname(checkpoint_path), constant.DENSE_UPDATE_VARIABLES)
    ops.add_to_collection(
        ops.GraphKeys.ASSET_FILEPATHS,
        tf.constant(
            dense_train_vars_path,
            dtype=tf.string,
            name=constant.DENSE_UPDATE_VARIABLES))

    asset_file = 'incr_update.txt'
    asset_file_path = os.path.join(export_dir, asset_file)
    # NOTE(review): unlike embed_name_to_ids.txt this file is not removed at
    # the end of the export - confirm whether that is intended.
    incr_update_json = _incr_update_to_json(oss_params['incr_update'])
    with GFile(asset_file_path, 'w') as fout:
      json.dump(incr_update_json, fout, indent=2)

    ops.add_to_collection(
        ops.GraphKeys.ASSET_FILEPATHS,
        tf.constant(asset_file_path, dtype=tf.string, name=asset_file))

  export_dir = os.path.join(export_dir,
                            meta_graph_def.meta_info_def.meta_graph_version)
  export_dir = io_util.fix_oss_dir(export_dir)
  logging.info('export_dir=%s' % export_dir)
  if Exists(export_dir):
    logging.info('will delete old dir: %s' % export_dir)
    DeleteRecursively(export_dir)

  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  # inputs / outputs belong to the graph built before the reset, so the
  # tensors are looked up by name in the re-imported graph
  signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs=_build_tensor_infos(graph, inputs),
          outputs=_build_tensor_infos(graph, outputs),
          method_name=signature_constants.PREDICT_METHOD_NAME))

  if 'incr_update' in oss_params:
    # the sparse update dicts of the editor are extended in place with the
    # dense ones, so that one signature covers both
    incr_update_inputs = meta_graph_editor.sparse_update_inputs
    incr_update_outputs = meta_graph_editor.sparse_update_outputs
    incr_update_inputs.update(meta_graph_editor.dense_update_inputs)
    incr_update_outputs.update(meta_graph_editor.dense_update_outputs)
    incr_update_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=_build_tensor_infos(graph, incr_update_inputs),
            outputs=_build_tensor_infos(graph, incr_update_outputs),
            method_name=signature_constants.PREDICT_METHOD_NAME))
  else:
    incr_update_signature = None

  saver = tf.train.Saver()
  with tf.Session(target=server.target if server else '') as sess:
    saver.restore(sess, checkpoint_path)
    main_op = tf.group([
        Scaffold.default_local_init_op(),
        ops.get_collection(EMBEDDING_INITIALIZERS)
    ])
    signature_def_map = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    }
    if incr_update_signature is not None:
      signature_def_map[INCR_UPDATE_SIGNATURE_KEY] = incr_update_signature
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map=signature_def_map,
        assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
        saver=saver,
        main_op=main_op,
        strip_default_attrs=True,
        clear_devices=True)
    builder.save()

  # remove temporary files
  Remove(embed_name_to_id_file)
  return export_dir