# Author: FH.
# Overview:
# tf.data
#     experimental              --- module
#     Dataset                   --- class
#     FixedLengthRecordDataset  --- class
#     TFRecordDataset           --- class
#     TextLineDataset           --- class
- 9 import tensorflow as tf
- 10 import numpy as np
- 11
- 12
# 1. Use the static method tf.data.Dataset.from_tensor_slices to slice the
#    input along its first dimension and form a dataset.
# 2. Use the Dataset's make_one_shot_iterator() to instantiate an iterator.
#    This iterator can only be read once from start to end ("one shot iterator").
def test1():
    """Build three datasets with ``from_tensor_slices`` and print their elements.

    Three inputs are sliced along their first dimension:

    * a 1-D float array of 5 values,
    * a 3x2 integer array (each element is one row of 2 values),
    * a dict with keys ``"a"`` (5 values) and ``"b"`` (random 5x2 array);
      each element is a dict holding the matching slice of both entries.

    Every dataset is read through a one-shot iterator, which can only be
    consumed once from start to end, and each element is printed.
    """
    dataset1 = tf.data.Dataset.from_tensor_slices(
        np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    )
    dataset2 = tf.data.Dataset.from_tensor_slices(
        np.array([[1, 2], [3, 4], [0, 9]])
    )
    dataset3 = tf.data.Dataset.from_tensor_slices(
        {
            "a": np.array([1.0, 2, 3, 4, 5.0]),
            "b": np.random.uniform(size=(5, 2)),
        }
    )

    # Instantiate a one-shot iterator per dataset; get_next() returns the
    # tensor(s) that yield the next element each time they are evaluated.
    element1 = dataset1.make_one_shot_iterator().get_next()
    element2 = dataset2.make_one_shot_iterator().get_next()
    element3 = dataset3.make_one_shot_iterator().get_next()

    # The context manager closes the session even if a run() call raises.
    with tf.Session() as sess:
        for _ in range(5):
            print(sess.run(element1))
        for _ in range(3):
            print(sess.run(element2))
        for _ in range(5):
            print(sess.run(element3))
- 42
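# 1. Transforming the elements of a Dataset.
#    map():   takes a function; every element of the dataset is passed
#             through it to obtain a new value.
#    batch(): takes an integer; combines that many elements into one batch.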
def test2():
    """Demonstrate ``map``, ``batch``, ``shuffle`` and ``repeat`` on a Dataset.

    A base dataset of 6 scalar values is transformed four independent ways,
    and every element of each resulting dataset is printed:

    * ``map``     -- each value multiplied by 3 (6 elements),
    * ``batch``   -- values grouped in pairs (3 elements),
    * ``shuffle`` -- values shuffled with a buffer of 3 (6 elements),
    * ``repeat``  -- the whole sequence repeated 4 times (24 elements).
    """
    dataset = tf.data.Dataset.from_tensor_slices(
        np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6])
    )

    # map(): re-map every element to a new value.
    dataset1 = dataset.map(lambda x: x * 3)
    # batch(): group every 2 elements into one batch; 3 batches in total.
    dataset2 = dataset.batch(2)
    # shuffle(): shuffle the dataset.
    dataset3 = dataset.shuffle(buffer_size=3)
    # repeat(): repeat the whole sequence 4 times; 24 elements in total.
    dataset4 = dataset.repeat(4)

    element1 = dataset1.make_one_shot_iterator().get_next()
    element2 = dataset2.make_one_shot_iterator().get_next()
    element3 = dataset3.make_one_shot_iterator().get_next()
    element4 = dataset4.make_one_shot_iterator().get_next()

    # The context manager closes the session even if a run() call raises.
    with tf.Session() as sess:
        for _ in range(6):  # map()
            print(sess.run(element1))
        for _ in range(3):  # batch()
            print(sess.run(element2))
        for _ in range(6):  # shuffle()
            print(sess.run(element3))
        for _ in range(24):  # repeat()
            print(sess.run(element4))
- 75
# example1: read images and their corresponding labels, shuffle them, and
# form a dataset with batch_size=2 that is repeated for 10 epochs.
def _parse_function(imgfilename, label):
    """Load one image file, decode it and resize it to 256x256.

    Args:
        imgfilename: Path of the image file to read.
        label: Label associated with the image; passed through unchanged.

    Returns:
        A ``(img, label)`` tuple where ``img`` is the decoded image resized
        to 256x256.
    """
    image_value = tf.read_file(imgfilename)
    # tf.image.decode_image returns a tensor with no static shape, which
    # tf.image.resize_images rejects when the graph is built. decode_jpeg
    # yields a rank-3 tensor, and channels=3 fixes the channel count so the
    # resized images can later be batched together.
    img = tf.image.decode_jpeg(image_value, channels=3)
    img = tf.image.resize_images(img, [256, 256])
    return img, label
def example1():
    """Build a shuffled, batched and repeated dataset of (image, label) pairs.

    The file names and labels are sliced into (name, label) elements, each
    name is mapped to its decoded image via ``_parse_function``, and the
    result is shuffled (buffer of 3), batched in pairs and repeated 10 times.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    # Image file list.
    filenames = tf.constant(
        ['name1.jpg', 'name3.jpg', 'name5.jpg',
         'name6.jpg', 'name7.jpg', 'name8.jpg']
    )
    # Corresponding labels.
    labels = tf.constant([0, 1, 0, 1, 1, 0])
    # Dataset of (file name, label) elements.
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    # Map each element to (image, label).
    dataset = dataset.map(_parse_function)
    # shuffle, batch, repeat. The keyword is `buffer_size`; the previous
    # `buffersize` spelling raised TypeError.
    dataset = dataset.shuffle(buffer_size=3).batch(2).repeat(10)
    return dataset
- 95
# Script entry point: run the dataset-transformation demo when this file is
# executed directly.
if __name__ == '__main__':
    test2()