经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » TensorFlow » 查看文章
tensorflow-- Dataset创建数据集对象
来源:cnblogs  作者:feihu_h  时间:2019/10/15 14:44:52  对本文有异议

tf.data模块包含:

  •  experimental 模块
  •  Dataset 类
  •  FixedLengthRecordDataset 类
  • TFRecordDataset 类
  • TextLineDataset 类
  1.  
  1. 1 # author by FH.
  2. 2 # OverView:
  3. 3 # tf.data
  4. 4 # experimental ---Modules
  5. 5 # Dataset ---class
  6. 6 # FixedLengthRecordDataset ---class
  7. 7 # TFRecordDataset ---class
  8. 8 # TextLineDataset ---class
  9. 9 import tensorflow as tf
  10. 10 import numpy as np
  11. 11
  12. 12
  13. 13 # 1. 使用静态方法 tf.data.Dataset.from_tensor_slices
  14. 14 # 将输入的第一个维度切割,形成dataset
  15. 15 # 2. 使用 Dataset的 make_one_shot_iterator() 实例化一个 iterator
  16. 16 # 这个iterator 只能从头到尾读取一次。“one shot iterator”
  17. 17 def test1():
  18. 18 sess = tf.Session()
  19. 19 dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
  20. 20 dataset2 = tf.data.Dataset.from_tensor_slices(np.array([[1,2],[3,4],[0,9]]))
  21. 21 dataset3 = tf.data.Dataset.from_tensor_slices(
  22. 22 {
  23. 23 "a":np.array([1.0,2,3,4,5.0]),
  24. 24 "b":np.random.uniform(size=(5,2))
  25. 25 }
  26. 26 )
  27. 27 # 使用 Dataset的 make_one_shot_iterator() 实例化一个 iterator
  28. 28 # 这个iterator 只能从头到尾读取一次。“one shot iterator”
  29. 29 oneShotIterator1 = dataset1.make_one_shot_iterator()
  30. 30 oneShotIterator2 = dataset2.make_one_shot_iterator()
  31. 31 oneShotIterator3 = dataset3.make_one_shot_iterator()
  32. 32 element1 = oneShotIterator1.get_next()
  33. 33 element2 = oneShotIterator2.get_next()
  34. 34 element3 = oneShotIterator3.get_next()
  35. 35 for i in range(5):
  36. 36 print(sess.run(element1))
  37. 37 for i in range(3):
  38. 38 print(sess.run(element2))
  39. 39 for i in range(5):
  40. 40 print(sess.run(element3))
  41. 41 sess.close()
  42. 42
  43. 43 # 1.Dataset 中的数据元素转换。
  44. 44 # map() :参数为一个函数,将dataset中的每个元素带入获取新的值
  45. 45 # batch(): 参数为一个整数,将多个元素组合成一个batch
  46. 46 def test2():
  47. 47 sess = tf.Session()
  48. 48 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0,6]))
  49. 49 # map() 重新映射新的元素值
  50. 50 dataset1 = dataset.map(lambda x: x * 3)
  51. 51 # batch() 2个组成一个batch, 组成batch 之后size 为3
  52. 52 dataset2 = dataset.batch(2)
  53. 53 # shuffle() 打乱dataset
  54. 54 dataset3 = dataset.shuffle(buffer_size=3)
  55. 55 # repeat() 将整个序列重复多次,重复4次 size 为24
  56. 56 dataset4 = dataset.repeat(4)
  57. 57
  58. 58 oneShotIterator1 = dataset1.make_one_shot_iterator()
  59. 59 oneShotIterator2 = dataset2.make_one_shot_iterator()
  60. 60 oneShotIterator3 = dataset3.make_one_shot_iterator()
  61. 61 oneShotIterator4 = dataset4.make_one_shot_iterator()
  62. 62 element1 = oneShotIterator1.get_next()
  63. 63 element2 = oneShotIterator2.get_next()
  64. 64 element3 = oneShotIterator3.get_next()
  65. 65 element4 = oneShotIterator4.get_next()
  66. 66 for i in range(6): # map()
  67. 67 print(sess.run(element1))
  68. 68 for i in range(3): # batch()
  69. 69 print(sess.run(element2))
  70. 70 for i in range(6): # shuffle()
  71. 71 print(sess.run(element3))
  72. 72 for i in range(24): # repeat()
  73. 73 print(sess.run(element4))
  74. 74 sess.close()
  75. 75
  76. 76 # example1: 读取图片和相应的标签并打乱,组成
  77. 77 # batch_size=2 的数据集,重复10 epoch
  78. 78 def _parse_function(imgfilename,label):
  79. 79 image_value = tf.read_file(imgfilename)
  80. 80 img = tf.image.decode_image(image_value)
  81. 81 img = tf.image.resize_images(img,[256,256])
  82. 82 return img,label
  83. 83 def example1():
  84. 84 # 图片列表
  85. 85 filesnames = tf.constant(['name1.jpg','name3.jpg','name5.jpg','name6.jpg','name7.jpg','name8.jpg'])
  86. 86 # 对应标签
  87. 87 labels = tf.constant([0,1,0,1,1,0])
  88. 88 # dataset (名称,标签)
  89. 89 dataset = tf.data.Dataset.from_tensor_slices((filesnames,labels))
  90. 90 # map 映射成图片和标签
  91. 91 dataset = dataset.map(_parse_function)
  92. 92 # shuffle ,batch , repeat
  93. 93 dataset = dataset.shuffle(buffersize=3).batch(2).repeat(10)
  94. 94 return dataset
  95. 95
  96. 96 if __name__ == '__main__':
  97. 97 test2()
View Code

 

原文链接:http://www.cnblogs.com/feihu-h/p/11677443.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号