public:tensorflow:saved_model

05 - Tensorflow 模型保存与恢复

本文大部分是对官网的描述做些备注和个人理解。如有需要,请直接查看官网原文。

https://tensorflow.google.cn/programmers_guide/saved_model (英文原版)

https://tensorflow.google.cn/programmers_guide/saved_model?hl=zh-CN (中文原版)

import tensorflow as tf
 
# build network
# ...
 
# define checkpoint saver
saver = tf.train.Saver()
 
with tf.Session() as sess:
    # init variables
    sess.run(tf.global_variables_initializer())
 
    # 保存 checkpoint
    saver.save(sess, save_path="/path/to/save/checkpoints", global_step=tf.train.get_global_step())

说明: save_path 是到文件名层面的。若保存路径设为“/data/model”,则 checkpoint 的文件以“/data/model-{global_step}“开头。一次会生成三个文件,分别以”data-00000-of-00001”、“index”、“meta”为扩展名。

如 save_path 为“/data/model”,step 为 100 时,生成“/data/model-100.data-00000-of-00001”、“/data/model-100.index”、“/data/model-100.meta”三个文件。

使用 `estimator` 时,开发者不需要关心 checkpoint 的保存和恢复。 `estimator` 只需要指定保存路径,即文件夹层次,保存的文件名默认为“model.ckpt”,因此生成出来的文件名为“model.ckpt-{global_step}”

import tensorflow as tf
 
# build network
# ...
 
# define checkpoint saver
saver = tf.train.Saver()
 
with tf.Session() as sess:
    # you don't need to init variables. 不需要初始化变量,因为这些变量直接从checkpoint中恢复。(当然,先初始化也不会有问题。)
    #sess.run(tf.global_variables_initializer())
 
    # 恢复 checkpoint
    saver.restore(sess, save_path="/path/to/save/checkpoints")

说明: 与保存一样,save_path 是到文件名层面的,不同的是,这里是要指定到 global_step。

如之前保存过“/data/model-100.data-00000-of-00001”,则恢复时 save_path 为 “/data/model-100”

使用 `estimator` 时,开发者不需要关心 checkpoint 的保存和恢复。 `estimator` 只需要指定保存路径,即文件夹层次,默认会从该目录下的 “checkpoint” 文件中取到最新的 checkpoint 文件名(通过 step 数的大小),然后加载恢复。

注意: checkpoint 数据中不包含网络模型,只包含超参数等。因此恢复的时候,同样需要先创建同样的网络模型,才能恢复成功。

有时候并不需要保存和恢复所有变量,这时可以手动指定需要保存和恢复的变量名。

示例不展开,请参考官网:

https://tensorflow.google.cn/programmers_guide/saved_model

官网示例,使用 `inspect_checkpoint` 库快速检查某个 checkpoint 中的变量。

# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
 
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
 
# tensor_name:  v1
# [ 1.  1.  1.]
# tensor_name:  v2
# [-1. -1. -1. -1. -1.]
 
# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False)
 
# tensor_name:  v1
# [ 1.  1.  1.]
 
# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False)
 
# tensor_name:  v2
# [-1. -1. -1. -1. -1.]

2.1.1 简单保存

# 直接调用 simple_save
tf.saved_model.simple_save(session,
            export_dir,
            inputs={"x": x, "y": y},
            outputs={"z": z})

说明: 这里的 export_dir 是文件夹层次的。所有跟此模型相关的数据都在该文件夹中。

2.1.2 手动保存

import tensorflow as tf
 
builder = tf.saved_model.builder.SavedModelBuilder(export_dir="/path/to/save/model")
with tf.Session(graph=tf.Graph()) as sess:
    # ...
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], strip_default_attrs=True)
builder.save()

说明: 参数的说明请参数官网

import tensorflow as tf
 
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir="/path/to/save/model")

说明: 第二个参数 tag 标签要和保存时的一致。模型中已经包含了网络模型和参数,加载完之后可以直接使用。输入输出对应的 tensor 需要通过 `sess.graph.get_tensor_by_name` 方法获取。

2.3.1 输入函数的定制

官方示例:

feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}
 
def serving_input_receiver_fn():
    """An input receiver that expects a serialized tf.Example."""
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
    receiver_tensors = {'examples': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

个人测试示例:

def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.img_width, FLAGS.img_height, 3], name='input_tensors')
    return tf.estimator.export.ServingInputReceiver({"image_data": serialized_tf_example}, {"predictor_inputs": serialized_tf_example})

说明: 待补充。 这里各个值的意思还有点迷糊,后续补充介绍完整。

2.3.2 导出模型

estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn,
                            strip_default_attrs=True)

说明: 它在给定的 export_dir_base(即 export_dir_base/<timestamp>)下面创建一个带时间戳的导出目录,并将 SavedModel 写入其中。

2.3.3 自定义 Estimator 中的指定输出

如果是自定义的 Estimator,则可以在 model_fn 的返回值 tf.estimator.EstimatorSpec 中加上 export_outputs 参数,该参数是一个字典。

个人测试示例:

def my_model_fn(features, labels, mode):
    # ...
    # predict = logits
 
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'probabilities': tf.nn.softmax(logits),
            'predict': predict
        }
        export_outputs = {"serving_default": tf.estimator.export.PredictOutput(outputs=predictions)}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs)
 
    # ...

说明: export_outputs 一般需要有一个键为 “serving_default” 的值,这个是一个 tensorflow 的常量, 定义在 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY。

如果 export_outputs 没给值,则 Estimator 内部会自动生成一个如上面示例一样的字典(key 和 value 也是一样的)。

注意:模型的加载是通用的。一般 Estimator 只负责训练和当场预测,当需要把模型导出时,已经是需要让模型运行在 server 上,因此 Estimator 没有额外的模型加载方法。

2.4.1 SaveModel 目录结构

SaveModel 保存出来的目录结构如下:

assets/
assets.extra/
variables/
    variables.data-?????-of-?????
    variables.index
saved_model.pb
saved_model.pbtxt

2.4.2 CLI 工具

可以使用 tensorflow 自带的 cli 工具查看或者运行 SaveModel。

saved_model_cli 有三种命令: show, run, scan

可分别运行以下命令查看各命令的帮助文档:

saved_model_cli show -h
saved_model_cli run  -h
saved_model_cli scan -h

这里只介绍 show 的功能,其他功能需自行研究。

2.4.2.1 show
# 只显示 tag 集合:
saved_model_cli show --dir /path/to/model_dir
 
# 显示特定 tag 的 SignatureDefs key 集合:
saved_model_cli show --dir /path/to/model_dir --tag_set <tag>
 
# 显示特定 key 的 SignatureDef:
saved_model_cli show --dir /path/to/model_dir --tag_set <tag> --signature_def <key>
 
# 显示所有信息:
saved_model_cli show --dir /path/to/model_dir --all

说明: SignatureDefs 信息包含了 input 和 output 的 tensor 信息,便于在加载模型时,知道相应的 tensor 的名称。

更多关于 SignatureDefs 的相关情况,请查看 tensorflow server 相关信息。

  • 最后更改: 2019/10/31 15:45
  • 由 Jinkin Liu