in lingvo/core/inference_graph_exporter.py [0:0]
def Export(cls,
model_cfg,
model_task_name=None,
device_options=InferenceDeviceOptions(
device='',
retain_device_placement=False,
var_options=None,
gen_init_op=True,
dtype_override=None,
fprop_dtype_override=None),
freeze_checkpoint=None,
freeze_defaults=False,
export_path=None,
subgraph_filter=None,
random_seed=None,
disable_packed_input=True,
prune_graph=True,
export_graph_collections=False):
"""Exports a InferenceGraph proto with piecewise subgraphs.
Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.
Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
and multi-core inference on TPUs work properly.
Args:
model_cfg: a Params instance as returned by
model_registry.GetParams(modelname, 'Test') or model_params.Model().
model_task_name: The task to generate an inference graph for. Should be
None for single-task models.
device_options: Device options for the accelerator used for serving.
freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
given.
freeze_defaults: Default-initializes the graph and freezes it. Useful for
early testing of downstream tools without having a checkpoint.
export_path: If not None, write the inference graph in ASCII to this path.
subgraph_filter: A string or a list of subgraph names. If not None or
empty, export only this list of inference subgraphs.
random_seed: Fixes the random seed in the exported inference graph.
disable_packed_input: Disable packed input for inference writing purposes.
prune_graph: If true, prune the graph to just the parts we need.
export_graph_collections: If true, export graph collections to the
InferenceGraph proto.
Returns:
InferenceGraph proto.
Raises:
ValueError: if the model does not support the listed subgraphs.
"""
if py_utils.IsEagerMode():
raise ValueError('InferenceGraph exporter does not work in Eager mode.')
assert issubclass(model_cfg.cls, base_model.BaseModel)
if device_options.dtype_override and device_options.fprop_dtype_override:
raise ValueError(
'device_options.{dtype_override, fprop_dtype_override} cannot both be '
'set.')
if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
subgraph_filter = [subgraph_filter]
# Disable assertions unless the user explicitly enables them.
if FLAGS['enable_asserts'].using_default_value:
FLAGS.enable_asserts = False
# TODO(laurenzo): Work out how much we need to specify here in terms of
# cluster configuration.
cls._SetClusterParams(model_cfg.cluster, device_options)
# Configure the model.
model_cfg.random_seed = random_seed
model_cfg.is_inference = True
if disable_packed_input:
def _DisablePackedInput(task):
if (_ParamExists(task, 'encoder') and
_ParamExists(task.encoder, 'packed_input')):
task.encoder.packed_input = False
if (_ParamExists(task, 'decoder') and
_ParamExists(task.decoder, 'packed_input')):
task.decoder.packed_input = False
if issubclass(model_cfg.cls, base_model.MultiTaskModel):
for _, task_param in model_cfg.task_params.IterParams():
_DisablePackedInput(task_param)
else:
_DisablePackedInput(model_cfg.task)
tf.logging.debug('Model %s params:', model_cfg.name)
for line in model_cfg.ToText().split('\n'):
tf.logging.debug('%s', line)
# Instantiate the graph.
graph = tf.Graph()
with graph.as_default():
tf.random.set_seed(random_seed)
cluster = model_cfg.cluster.Instantiate()
device = cluster.GetPlacer()
tpu_const_scope = _DummyScope()
if (IsTpu(device_options) and
device_options.var_options == 'AS_CONSTANTS'):
# Do not specify devices for variables if we are marking them as
# constants.
device = ''
tpu_const_scope = ConstGuaranteeScope()
with cluster, tf.device(device), tpu_const_scope:
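# Depending on device_options, force bfloat16 for both weights and
# activations, or for activations only, before instantiating the model.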
bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
device_options)
if bfloat16_override:
py_utils.UpdateDtype(model_cfg, tf.bfloat16)
py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)
act_bfloat16_override = ShouldForceBfloat16ForActivations(
device_options)
if act_bfloat16_override:
py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)
# Hard-code TPU-related flags prior to instantiating model.
old_enable_asserts = FLAGS.enable_asserts
old_xla_device = FLAGS.xla_device
if IsTpu(device_options):
FLAGS.enable_asserts = False
FLAGS.xla_device = 'tpu'
try:
mdl = model_cfg.Instantiate()
task = mdl.GetTask(model_task_name)
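# If the model maintains exponential moving averages, restore the EMA
# shadow variables instead of the raw trainable variables.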
variables_to_restore = (
_MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
mdl.ema.variables_to_restore(mdl.variables_for_ema))
if bfloat16_override:
saver_var_spec = (
bfloat16_variables
.get_saver_spec_for_variables_with_bf16_overrides(
variables_to_restore))
# For TPU embedding layers, if the table explicitly specifies the
# inference dtype as bfloat16, the variables in the checkpoint must
# already be in bfloat16, so we change back to bfloat16 to avoid
# dtype mismatch.
for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
.inference_with_bfloat16_var_names):
saver_var_spec[var_name] = variables_to_restore[var_name]
else:
saver_var_spec = variables_to_restore
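# Build a Saver over the chosen variable spec so checkpoints can be
# restored (or frozen from) later.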
saver = tf.train.Saver(saver_var_spec)
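# Create a named init op over all variables; the 'init_all_variables' name
# is kept alive when the graph is pruned below.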
tf.variables_initializer(
tf.global_variables(), name='init_all_variables')
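# On TPU, also emit a named op that initializes the TPU system; it is
# likewise preserved during pruning.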
if IsTpu(device_options) and device_options.gen_init_op:
tf.group(tf.tpu.initialize_system(), name='tpu_init_op')
if freeze_checkpoint or freeze_defaults:
# Replace variables with tensors using tf.identity in theta before
# freezing to avoid the graph referencing types of DT_RESOURCE.
def AddIdentityToTheta(layer):
# pylint: disable=protected-access
layer._private_theta = py_utils.Transform(tf.identity,
layer._private_theta)
# pylint: enable=protected-access
layer.children.Transform(AddIdentityToTheta)
AddIdentityToTheta(task)
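# task.Inference() may return either a plain dict of subgraphs or an
# InferenceGraph proto; normalize to the proto form.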
inference_graph_proto = inference_graph_pb2.InferenceGraph()
subgraphs_proto = task.Inference()
if isinstance(subgraphs_proto, dict):
subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
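# Keep only the requested subgraphs (all of them when no filter is given).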
for name, subgraph in subgraphs_proto.subgraphs.items():
if not subgraph_filter or name in subgraph_filter:
inference_graph_proto.subgraphs[name].CopyFrom(subgraph)
if not inference_graph_proto.subgraphs and subgraph_filter:
raise ValueError(
f'Subgraph filters {subgraph_filter} filtered out all '
'subgraphs. Defined subgraphs: '
f'{list(subgraphs_proto.subgraphs.keys())}')
# Yes, graph collections are bad, but this seems to be the easiest
# way to get these assets registered from TextFileInitializer.
assets_collection = tf.compat.v1.get_collection(
tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
for asset in assets_collection:
if asset.op.type == 'Const' and asset.op.get_attr(
'dtype') == tf.dtypes.string:
constant_value = asset.op.get_attr('value')
if constant_value.string_val:
tf.logging.info('Found asset file_path: %s',
constant_value.string_val[0])
asset_file_def = inference_graph_proto.asset_file_def.add()
asset_file_def.tensor_info.name = asset.name
asset_file_def.filename = constant_value.string_val[0]
# Add a table init op and global variable init op to the graph.
# Tables can be declared anywhere in the graph, so this op has to be
# added last.
tf.tables_initializer(name='init_all_tables')
finally:
# Reset TPU-related flags after model instantiation.
FLAGS.enable_asserts = old_enable_asserts
FLAGS.xla_device = old_xla_device
tf.logging.info('Graph contains ops: %r',
[op.name for op in graph.get_operations()])
# Collection defs
if not tf.executing_eagerly():
if export_graph_collections:
meta_graph = tf.train.export_meta_graph(graph=graph)
for key in meta_graph.collection_def:
tf.logging.info('copying collection %s', key)
inference_graph_proto.collection_def[key].CopyFrom(
meta_graph.collection_def[key])
else:
tf.logging.warning('Not exporting collection defs '
'since operating in eager mode.')
# Freezing.
if freeze_defaults or freeze_checkpoint:
output_op_names = GetOutputOpNames(
graph,
inference_graph_proto,
preserve_colocation_nodes=False,
preserve_saver_restore_nodes=False)
if cls._DeviceSupportsFreezing(device_options):
raise ValueError('freeze_checkpoint cannot be used with device ' +
device_options.device)
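# Fold variable values into the graph as constants, either from the given
# checkpoint or from the default initializers.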
if freeze_checkpoint:
tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
output_op_names)
elif freeze_defaults:
tf.logging.info('Default initializing graph and freezing.')
graph_def = _FreezeDefaults(graph, output_op_names)
else:
inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
graph_def = graph.as_graph_def()
if prune_graph:
output_op_names = GetOutputOpNames(graph, inference_graph_proto)
# Prune the graph to just the parts we need.
# To support restoring, we have to not prune out the restore node.
output_op_names.append('init_all_tables')
output_op_names.append('init_all_variables')
output_op_names.append('save/control_dependency')
output_op_names.append('save/restore_all')
if IsTpu(device_options) and device_options.gen_init_op:
output_op_names.append('tpu_init_op')
tf.logging.info('Pruning graph to output ops: %r', output_op_names)
graph_def = tf.compat.v1.graph_util.extract_sub_graph(
graph_def, output_op_names)
if not device_options.retain_device_placement:
# Clear the device so that the runtime can choose.
tf.logging.info('Clearing device placement for: %s',
device_options.device)
for node in graph_def.node:
node.ClearField('device')
for function in graph_def.library.function:
for node_def in function.node_def:
node_def.ClearField('device')
inference_graph_proto.graph_def.CopyFrom(graph_def)
if export_path:
with tf.io.gfile.GFile(export_path, 'w') as f:
f.write(text_format.MessageToString(inference_graph_proto))
return inference_graph_proto