JobPlus知识库 IT 工业智能4.0 文章
Tensorflow— 使用inception-v3做各种图像的识别

代码:

[python] 

  1. import tensorflow as tf  
  2. import os  
  3. import numpy as np  
  4. import re  
  5. from PIL import Image  
  6. import matplotlib.pyplot as plt  

代码:

[python] 

  1. class NodeLookup(object):  
  2.     def __init__(self):    
  3.         label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'     
  4.         uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'  
  5.         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)  
  6.   
  7.     def load(self, label_lookup_path, uid_lookup_path):  
  8.         # 加载分类字符串n********对应分类名称的文件  
  9.         proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()  
  10.         uid_to_human = {}  
  11.         #一行一行读取数据  
  12.         for line in proto_as_ascii_lines :  
  13.             #去掉换行符  
  14.             line=line.strip('\n')  
  15.             #按照'\t'分割  
  16.             parsed_items = line.split('\t')  
  17.             #获取分类编号  
  18.             uid = parsed_items[0]  
  19.             #获取分类名称  
  20.             human_string = parsed_items[1]  
  21.             #保存编号字符串n********与分类名称映射关系  
  22.             uid_to_human[uid] = human_string  
  23.   
  24.         # 加载分类字符串n********对应分类编号1-1000的文件  
  25.         proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()  
  26.         node_id_to_uid = {}  
  27.         for line in proto_as_ascii:  
  28.             if line.startswith('  target_class:'):  
  29.                 #获取分类编号1-1000  
  30.                 target_class = int(line.split(': ')[1])  
  31.             if line.startswith('  target_class_string:'):  
  32.                 #获取编号字符串n********  
  33.                 target_class_string = line.split(': ')[1]  
  34.                 #保存分类编号1-1000与编号字符串n********映射关系  
  35.                 node_id_to_uid[target_class] = target_class_string[1:-2]  
  36.   
  37.         #建立分类编号1-1000对应分类名称的映射关系  
  38.         node_id_to_name = {}  
  39.         for key, val in node_id_to_uid.items():  
  40.             #获取分类名称  
  41.             name = uid_to_human[val]  
  42.             #建立分类编号1-1000到分类名称的映射关系  
  43.             node_id_to_name[key] = name  
  44.         return node_id_to_name  
  45.   
  46.     #传入分类编号1-1000返回分类名称  
  47.     def id_to_string(self, node_id):  
  48.         if node_id not in self.node_lookup:  
  49.             return ''  
  50.         return self.node_lookup[node_id]  
  51.   
  52.   
  53. #创建一个图来存放google训练好的模型  
  54. with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:  
  55.     graph_def = tf.GraphDef()  
  56.     graph_def.ParseFromString(f.read())  
  57.     tf.import_graph_def(graph_def, name='')  
  58.   
  59.   
  60. with tf.Session() as sess:  
  61.     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')  
  62.     #遍历目录  
  63.     for root,dirs,files in os.walk('image/'):  
  64.         for file in files:  
  65.             #载入图片  
  66.             image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()  
  67.             predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式  
  68.             predictions = np.squeeze(predictions)#把结果转为1维数据  
  69.   
  70.             #打印图片路径及名称  
  71.             image_path = os.path.join(root,file)  
  72.             print(image_path)  
  73.             #显示图片  
  74.             img=Image.open(image_path)  
  75.             plt.imshow(img)  
  76.             plt.axis('off')  
  77.             plt.show()  
  78.   
  79.             #排序  
  80.             top_k = predictions.argsort()[-5:][::-1]  
  81.             node_lookup = NodeLookup()  
  82.             for node_id in top_k:       
  83.                 #获取分类名称  
  84.                 human_string = node_lookup.id_to_string(node_id)  
  85.                 #获取该分类的置信度  
  86.                 score = predictions[node_id]  
  87.                 print('%s (score = %.5f)' % (human_string, score))  
  88.             print()  

运行结果:


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
182人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序