从tensorflow的github克隆下来后,git checkout r1.6 然后在tensorflow目录下跑flower_photos数据集
训练命令如下:
python examples/image_retraining/retrain.py --learning_rate=0.0001 --testing_percentage=20 --validation_percentage=20 --train_batch_size=32 --validation_batch_size=-1 --flip_left_right True --random_scale=30 --random_brightness=30 --eval_step_interval=100 --how_many_training_steps=600 --architecture mobilenet_1.0_224 --image_dir ~/Documents/yolo-darkne/flower_photos
训练结束后,在/tmp/imagenet目录下得到模型文件和对应标签文件:output_graph.pb、output_labels.txt
用如下代码调用模型:
# -*- coding:utf-8 -*-
import tensorflow as tf
import sys
# 命令行参数,传入要判断的图片路径
image_file = sys.argv[1]
#print(image_file)
# 读取图像
image = tf.gfile.FastGFile(image_file, 'rb').read()
# 加载图像分类标签
labels = []
for label in tf.gfile.GFile("/tmp/imagenet/output_labels.txt"):
labels.append(label.rstrip())
# 加载Graph
with tf.gfile.FastGFile("/tmp/imagenet/output_graph.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})
# 根据分类概率进行排序
top = predict[0].argsort()[-len(predict[0]):][::-1]
for index in top:
human_string = labels[index]
score = predict[0][index]
print(human_string, score)
报错如下:
TypeError: Cannot interpret feed_dict key as Tensor: The name 'DecodeJpeg/contents:0' refers to a Tensor which does not exist. The operation, 'DecodeJpeg/contents', does not exist in the graph.
经查证,tensorflow1.2版本之后的版本,examples/image_retraining/retrain.py里已经没有DecodeJpeg/contents 这个tensor名称。 Google has updated their retrain.py in image_retraining. They removed the op:DecodeJpeg, so you can’t get it.
解决办法: 在tensorflow根目录,执行git checkout r1.2将版本切到1.2,重新训练,得到模型再调用,问题解决。