经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » TensorFlow » 查看文章
TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片
来源:jb51  时间:2019/3/15 8:35:19  对本文有异议

本文是Python通过TensorFlow卷积神经网络实现猫狗识别的姊妹篇,是加载上一篇训练好的模型,进行猫狗识别

本文逻辑:

  1. 我从网上下载了十几张猫和狗的图片,用于检验我们训练好的模型。
  2. 处理我们下载的图片
  3. 加载模型
  4. 将图片输入模型进行检验

代码如下:

  1. #coding=utf-8
  2. import tensorflow as tf
  3. from PIL import Image
  4. import matplotlib.pyplot as plt
  5. import input_data
  6. import numpy as np
  7. import model
  8. import os
  9. #从指定目录中选取一张图片
  10. def get_one_image(train):
  11. files = os.listdir(train)
  12. n = len(files)
  13. ind = np.random.randint(0,n)
  14. img_dir = os.path.join(train,files[ind])
  15. image = Image.open(img_dir)
  16. plt.imshow(image)
  17. plt.show()
  18. image = image.resize([208, 208])
  19. image = np.array(image)
  20. return image
  21. def evaluate_one_image():
  22. #存放的是我从百度下载的猫狗图片路径
  23. train = '/Users/yangyibo/GitWork/pythonLean/AI/猫狗识别/testImg/'
  24. image_array = get_one_image(train)
  25. with tf.Graph().as_default():
  26. BATCH_SIZE = 1 # 因为只读取一副图片 所以batch 设置为1
  27. N_CLASSES = 2 # 2个输出神经元,[1,0] 或者 [0,1]猫和狗的概率
  28. # 转化图片格式
  29. image = tf.cast(image_array, tf.float32)
  30. # 图片标准化
  31. image = tf.image.per_image_standardization(image)
  32. # 图片原来是三维的 [208, 208, 3] 重新定义图片形状 改为一个4D 四维的 tensor
  33. image = tf.reshape(image, [1, 208, 208, 3])
  34. logit = model.inference(image, BATCH_SIZE, N_CLASSES)
  35. # 因为 inference 的返回没有用激活函数,所以在这里对结果用softmax 激活
  36. logit = tf.nn.softmax(logit)
  37. # 用最原始的输入数据的方式向模型输入数据 placeholder
  38. x = tf.placeholder(tf.float32, shape=[208, 208, 3])
  39. # 我门存放模型的路径
  40. logs_train_dir = '/Users/yangyibo/GitWork/pythonLean/AI/猫狗识别/saveNet/'
  41. # 定义saver
  42. saver = tf.train.Saver()
  43. with tf.Session() as sess:
  44. print("从指定的路径中加载模型。。。。")
  45. # 将模型加载到sess 中
  46. ckpt = tf.train.get_checkpoint_state(logs_train_dir)
  47. if ckpt and ckpt.model_checkpoint_path:
  48. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  49. saver.restore(sess, ckpt.model_checkpoint_path)
  50. print('模型加载成功, 训练的步数为 %s' % global_step)
  51. else:
  52. print('模型加载失败,,,文件没有找到')
  53. # 将图片输入到模型计算
  54. prediction = sess.run(logit, feed_dict={x: image_array})
  55. # 获取输出结果中最大概率的索引
  56. max_index = np.argmax(prediction)
  57. if max_index==0:
  58. print('猫的概率 %.6f' %prediction[:, 0])
  59. else:
  60. print('狗的概率 %.6f' %prediction[:, 1])
  61. # 测试
  62. evaluate_one_image()

/Users/yangyibo/GitWork/pythonLean/AI/猫狗识别/testImg/ 存放的是我从百度下载的猫狗图片

执行结果:

因为从testimg 中选取图片是随机的,所以每次执行的结果不同

从指定的路径中加载模型。。。。
模型加载成功, 训练的步数为 11999
狗的概率 0.964047
[Finished in 6.8s]

代码地址:https://github.com/527515025/My-TensorFlow-tutorials/blob/master/猫狗识别/evaluateCatOrDog.py

欢迎star。

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作具有一定的参考学习价值,谢谢大家对w3xue的支持。如果你想了解更多相关内容请查看下面相关链接

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号