经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
记录:tf.saved_model 模块的简单使用(TensorFlow 模型存储与恢复)
来源:cnblogs  作者:买白菜不用券  时间:2018/12/3 9:49:26  对本文有异议

虽然说 TensorFlow 2.0 即将问世,但是有一些模块的内容却是不大变化的。其中就有 tf.saved_model 模块,主要用于模型的存储和恢复。为了防止学习记录文件丢失或者蠢笨的脑子直接遗忘掉这部分内容,在此做点简单的记录,以便将来查阅。

最近为了一个课程作业,不得已涉及到关于图像超分辨率恢复的内容,不得不准备随时存储训练的模型,只好再回过头来瞄一眼 TensorFlow 文档,真是太痛苦了。

tf.saved_model 模块下面有很多文件和函数,精力有限,只好选择于自己有用的东西来看,可能并不全面,望日后补上。

其中最重要的就是该模块下的一个类:tf.saved_model.builder.SavedModelBuilder

  1. tf.saved_model.builder.SavedModelBuilder:
  2. # 构造函数
  3. .__init__(export_dir)
  4. """
  5. 作用:
  6.   创建一个保存模型的实例对象
  7. 参数:
  8. export_dir: 模型导出路径,由于 TensorFlow 会在你指定的路径上创建文件夹和文件,所以指定的路径最后不需要带 /,
  9.    例如:export_dir='/home/***/saved_model' 即可,最后不需要加上 /
  10. """
  11.  
  12. # 方法
  13. # 1
  14. .add_meta_graph_and_variables(sess, tags, signature_def_map=None, assets_collection=None,
  15. clear_devices=False, main_op=None, strip_default_attrs=False, saver=None)
  16. """
  17. 作用:
  18.   保存会话对象中的 graph 和所有变量,具体描述可参见文档
  19. 参数:
  20.   sess: TensorFlow 会话对象,用于保存元图和变量
  21.   tags: 用于保存元图的标记集(如果存在多个图对象,需要设置保证每个图标签不一样),是一个列表
  22.   signature_def_map: 一个字典,保存模型时传入的参数,key 可以是字符串,也可以是 tf.saved_model.signature_constants 文件下预定义的变量,
  23. 值为 signatureDef protobuf(protobuf 是一种结构化的数据存储格式)
  24.   assets_collection: 略
  25.   clear_devices: 如果需要清除默认图上的设备信息,则设置为 true
  26.   main_op: 这个参数包括后面一系列与其相关的东西没有弄明白
  27.   strip_default_attrs: 如果设置为 True,将从 NodeDefs 中删除默认值属性
  28.   saver: tf.train.Saver 的一个实例,用于导出元图并保存变量
  29. """
  30.  
  31. # 2
  32. .add_meta_graph()
  33. """
  34. 作用:
  35.   其除了没有 sess 参数以外,其他参数和 .add_meta_graph_and_variables() 一模一样
  36.   调用此方法之前必须先调用 .add_meta_graph_and_variables() 方法
  37. """
  38.  
  39. # 3
  40. .save(as_text=False)
  41. """
  42. 作用:
  43.   将内建的 savedModel protobuf 写入磁盘
  44. """

除了这个最重要的类以外,tf.saved_model 模块还提供了一些方便构建 builder 和加载模型的函数方法。

  1. # 1
  2. tf.saved_model.utils.build_tensor_info(tensor)
  3. """
  4. 作用:
  5. 构建 TensorInfo protobuf,根据输入的 tensor 构建相应的 protobuf,返回的 TensorInfo 中包含输入 tensor 的 name,shape,dtype 信息
  6. 参数:
  7. tensor: Tensor 或 SparseTensor
  8. """
  9.  
  10. # 2
  11. tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None)
  12. """
  13. 作用:
  14. 构建 SignatureDef protobuf,并返回 SignatureDef protobuf
  15. 参数:
  16. inputs: 一个字典,键为字符串类型,值为关于 tensor 的信息,也就是上述的 .build_tensor_info() 函数返回的 TensorInfo protobuf
  17. outputs: 一个字典,同上
  18. method_name: SignatureDef 名称
  19. """
  20.  
  21. # 3
  22. tf.saved_model.utils.get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None)
  23. """
  24. 作用:
  25. 根据一个 TensorInfo protobuf 解析出一个 tensor
  26. 参数:
  27. tensor_info: 一个 TensorInfo protobuf
  28. graph: tensor 所存在的 graph,参数为 None 时,使用默认图
  29. import_scope: 给 tensor 的 name 加上前缀
  30. """
  31.  
  32. # 4
  33. tf.saved_model.loader.load(sess, tags, export_dir, import_scope=None, **saver_kwargs)
  34. """
  35. 作用:
  36. 加载已存储的模型
  37. 参数:
  38. sess: 用于恢复模型的 tf.Session() 对象
  39. tags: 用于标识 MetaGraphDef 的标记,应该和存储模型时使用的此参数完全一致
  40. export_dir: 模型存储路径
  41. import_scope: 加前缀
  42. """

除了这些以外,还有一些 TensorFlow 为了方便而预定义的一些变量,这些变量完全可以使用自定义字符串代替,不再赘述。详情:https://tensorflow.google.cn/api_docs/python/tf/saved_model

如果只看这些内容的话,确实会使人产生巨大的疑惑,下面是具体实践的例子:

  1. import tensorflow as tf
  2. from tensorflow import saved_model as sm
  3. # 首先定义一个极其简单的计算图
  4. X = tf.placeholder(tf.float32, shape=(3, ))
  5. scale = tf.Variable([10, 11, 12], dtype=tf.float32)
  6. y = tf.multiply(X, scale)
  7. # 在会话中运行
  8. with tf.Session() as sess:
  9. sess.run(tf.initializers.global_variables())
  10. value = sess.run(y, feed_dict={X: [1., 2., 3.]})
  11. print(value)
  12. # 准备存储模型
  13. path = '/home/×××/tf_model/model_1'
  14. builder = sm.builder.SavedModelBuilder(path)
  15. # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
  16. X_TensorInfo = sm.utils.build_tensor_info(X)
  17. scale_TensorInfo = sm.utils.build_tensor_info(scale)
  18. y_TensorInfo = sm.utils.build_tensor_info(y)
  19. # 构建 SignatureDef protobuf
  20. SignatureDef = sm.signature_def_utils.build_signature_def(
  21. inputs={'input_1': X_TensorInfo, 'input_2': scale_TensorInfo},
  22. outputs={'output': y_TensorInfo},
  23. method_name='what'
  24. )
  25. # 将 graph 和变量等信息写入 MetaGraphDef protobuf
  26. # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,TensorFlow 为了方便使用,不在新地方将自定义的字符串忘记,可以使用预定义的这些值
  27. builder.add_meta_graph_and_variables(sess, tags=[sm.tag_constants.TRAINING],
  28. signature_def_map={sm.signature_constants.CLASSIFY_INPUTS: SignatureDef}
  29. )
  30.  # 将 MetaGraphDef 写入磁盘
  31. builder.save()

这样我们就把模型整体存储到了磁盘中,而且我们将三个变量 X, scale, y 全部序列化后存储到了其中,所以恢复模型时便可以将他们完全解析出来:

  1. import tensorflow as tf
  2. from tensorflow import saved_model as sm
  3. # 需要建立一个会话对象,将模型恢复到其中
  4. with tf.Session() as sess:
  5. path = '/home/×××/tf_model/model_1'
  6. MetaGraphDef = sm.loader.load(sess, tags=[sm.tag_constants.TRAINING], export_dir=path)
  7. # 解析得到 SignatureDef protobuf
  8. SignatureDef_d = MetaGraphDef.signature_def
  9. SignatureDef = SignatureDef_d[sm.signature_constants.CLASSIFY_INPUTS]
  10. # 解析得到 3 个变量对应的 TensorInfo protobuf
  11. X_TensorInfo = SignatureDef.inputs['input_1']
  12. scale_TensorInfo = SignatureDef.inputs['input_2']
  13. y_TensorInfo = SignatureDef.outputs['output']
  14. # 解析得到具体 Tensor
  15. # .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
  16. X = sm.utils.get_tensor_from_tensor_info(X_TensorInfo, sess.graph)
  17. scale = sm.utils.get_tensor_from_tensor_info(scale_TensorInfo, sess.graph)
  18. y = sm.utils.get_tensor_from_tensor_info(y_TensorInfo, sess.graph)
  19. print(sess.run(scale))
  20. print(sess.run(y, feed_dict={X: [3., 2., 1.]}))
  21. # 输出
  22. [10. 11. 12.]
  23. [30. 22. 12.]

可以看出模型整体和变量个体都被完整地保存了下来。其中涉及的关于 protobuf 的知识,需要补习,在 TensorFlow 中好多地方都用到了相关的知识。上述恢复模型的代码中对具体的 TensorInfo protobuf 解析时,还可以使用另一种方式得到相应的 Tensor:

  1. # 已知 X_TensorInfo, scale_TensorInfo, y_TensorInfo
  2. X = sess.graph.get_tensor_by_name(X_TensorInfo.name)
  3. scale = sess.grpah.get_tensor_by_name(scale_TensorInfo.name)
  4. y = sess.graph.get_tensor_by_name(y_TensorInfo.name)
  5. # 因为 TensorFlow 构建 TensorInfo protobuf 时,使用了 Tensor 的 name 信息,所以可以直接读出来使用

 

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号