05 - Tensorflow 模型保存与恢复
本文大部分是对官网的描述做些备注和个人理解。如有需要,请直接查看官网原文。
https://tensorflow.google.cn/programmers_guide/saved_model (英文原版)
https://tensorflow.google.cn/programmers_guide/saved_model?hl=zh-CN (中文原版)
1. 保存与恢复变量 (即 Checkpoint)
1.1 保存变量
import tensorflow as tf # build network # ... # define checkpoint saver saver = tf.train.Saver() with tf.Session() as sess: # init variables sess.run(tf.global_variables_initializer()) # 保存 checkpoint saver.save(sess, save_path="/path/to/save/checkpoints", global_step=tf.train.get_global_step())
说明: save_path 是到文件名层面的。若保存路径设为“/data/model”,则 checkpoint 的文件以“/data/model-{global_step}“开头。一次会生成三个文件,分别以”data-00000-of-00001”、“index”、“meta”为扩展名。
如 save_path 为“/data/model”,step 为 100 时,生成“/data/model-100.data-00000-of-00001”、“/data/model-100.index”、“/data/model-100.meta”三个文件。
使用 `estimator` 时,开发者不需要关心 checkpoint 的保存和恢复。 `estimator` 只需要指定保存路径,即文件夹层次,保存的文件名默认为“model.ckpt”,因此生成出来的文件名为“model.ckpt-{global_step}”
1.2 恢复变量
import tensorflow as tf # build network # ... # define checkpoint saver saver = tf.train.Saver() with tf.Session() as sess: # you don't need to init variables. 不需要初始化变量,因为这些变量直接从checkpoint中恢复。(当然,先初始化也不会有问题。) #sess.run(tf.global_variables_initializer()) # 恢复 checkpoint saver.restore(sess, save_path="/path/to/save/checkpoints")
说明: 与保存一样,save_path 是到文件名层面的,不同的是,这里是要指定到 global_step。
如之前保存过“/data/model-100.data-00000-of-00001”,则恢复时 save_path 为 “/data/model-100”
使用 `estimator` 时,开发者不需要关心 checkpoint 的保存和恢复。 `estimator` 只需要指定保存路径,即文件夹层次,默认会从该目录下的 “checkpoint” 文件中取到最新的 checkpoint 文件名(通过 step 数的大小),然后加载恢复。
注意: checkpoint 数据中不包含网络模型,只包含超参数等。因此恢复的时候,同样需要先创建同样的网络模型,才能恢复成功。
1.3 选择部分变量来进行保存和恢复
有时候并不需要保存和恢复所有变量,这时可以手动指定需要保存和恢复的变量名。
示例不展开,请参考官网:
1.4 查看 checkpoint 文件中的变量值
官网示例,使用 `inspect_checkpoint` 库快速检查某个 checkpoint 中的变量。
# import the inspect_checkpoint library from tensorflow.python.tools import inspect_checkpoint as chkp # print all tensors in checkpoint file chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True) # tensor_name: v1 # [ 1. 1. 1.] # tensor_name: v2 # [-1. -1. -1. -1. -1.] # print only tensor v1 in checkpoint file chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False) # tensor_name: v1 # [ 1. 1. 1.] # print only tensor v2 in checkpoint file chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False) # tensor_name: v2 # [-1. -1. -1. -1. -1.]
2. 保存与恢复模型 (即 pb 文件)
2.1 保存模型
2.1.1 简单保存
# 直接调用 simple_save tf.saved_model.simple_save(session, export_dir, inputs={"x": x, "y": y}, outputs={"z": z})
说明: 这里的 export_dir 是文件夹层次的。所有跟此模型相关的数据都在该文件夹中。
2.1.2 手动保存
import tensorflow as tf builder = tf.saved_model.builder.SavedModelBuilder(export_dir="/path/to/save/model") with tf.Session(graph=tf.Graph()) as sess: # ... builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], strip_default_attrs=True) builder.save()
说明: 参数的说明请参数官网
2.2 加载模型
import tensorflow as tf with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir="/path/to/save/model")
说明: 第二个参数 tag 标签要和保存时的一致。模型中已经包含了网络模型和参数,加载完之后可以直接使用。输入输出对应的 tensor 需要通过 `sess.graph.get_tensor_by_name` 方法获取。
2.3 使用 Estimator 保存模型
2.3.1 输入函数的定制
官方示例:
feature_spec = {'foo': tf.FixedLenFeature(...), 'bar': tf.VarLenFeature(...)} def serving_input_receiver_fn(): """An input receiver that expects a serialized tf.Example.""" serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='input_example_tensor') receiver_tensors = {'examples': serialized_tf_example} features = tf.parse_example(serialized_tf_example, feature_spec) return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
个人测试示例:
def serving_input_receiver_fn(): serialized_tf_example = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.img_width, FLAGS.img_height, 3], name='input_tensors') return tf.estimator.export.ServingInputReceiver({"image_data": serialized_tf_example}, {"predictor_inputs": serialized_tf_example})
说明: 待补充。 这里各个值的意思还有点迷糊,后续补充介绍完整。
2.3.2 导出模型
estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn, strip_default_attrs=True)
说明: 它在给定的 export_dir_base(即 export_dir_base/<timestamp>)下面创建一个带时间戳的导出目录,并将 SavedModel 写入其中。
2.3.3 自定义 Estimator 中的指定输出
如果是自定义的 Estimator,则可以在 model_fn 的返回值 tf.estimator.EstimatorSpec 中加上 export_outputs 参数,该参数是一个字典。
个人测试示例:
def my_model_fn(features, labels, mode): # ... # predict = logits if mode == tf.estimator.ModeKeys.PREDICT: predictions = { 'probabilities': tf.nn.softmax(logits), 'predict': predict } export_outputs = {"serving_default": tf.estimator.export.PredictOutput(outputs=predictions)} return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs) # ...
说明: export_outputs 一般需要有一个键为 “serving_default” 的值,这个是一个 tensorflow 的常量, 定义在 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY。
如果 export_outputs 没给值,则 Estimator 内部会自动生成一个如上面示例一样的字典(key 和 value 也是一样的)。
注意:模型的加载是通用的。一般 Estimator 只负责训练和当场预测,当需要把模型导出时,已经是需要让模型运行在 server 上,因此 Estimator 没有额外的模型加载方法。
2.4 查看模型中的信息
2.4.1 SaveModel 目录结构
SaveModel 保存出来的目录结构如下:
assets/ assets.extra/ variables/ variables.data-?????-of-????? variables.index saved_model.pb saved_model.pbtxt
2.4.2 CLI 工具
可以使用 tensorflow 自带的 cli 工具查看或者运行 SaveModel。
saved_model_cli 有三种命令: show, run, scan
可分别运行以下命令查看各命令的帮助文档:
saved_model_cli show -h saved_model_cli run -h saved_model_cli scan -h
这里只介绍 show 的功能,其他功能需自行研究。
2.4.2.1 show
# 只显示 tag 集合: saved_model_cli show --dir /path/to/model_dir # 显示特定 tag 的 SignatureDefs key 集合: saved_model_cli show --dir /path/to/model_dir --tag_set <tag> # 显示特定 key 的 SignatureDef: saved_model_cli show --dir /path/to/model_dir --tag_set <tag> --signature_def <key> # 显示所有信息: saved_model_cli show --dir /path/to/model_dir --all
说明: SignatureDefs 信息包含了 input 和 output 的 tensor 信息,便于在加载模型时,知道相应的 tensor 的名称。
更多关于 SignatureDefs 的相关情况,请查看 tensorflow server 相关信息。