machine-learning

Question about the prediction script for dive_into_keras_vgg16

Open · littlecangbaby opened this issue 8 years ago · 1 comment

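Hi, the ex.py file in the directory runs fine, but where is the source code for the prediction part that comes after it? I don't really follow it. Could you post the complete script for reference? Thanks a lot.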

Now let's start making predictions.

First, write a function that loads and preprocesses an image.

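import cv2
import numpy as np

def load_image(imageurl):
    # Read the image (OpenCV loads it in BGR order) and resize it to 224x224
    im = cv2.resize(cv2.imread(imageurl), (224, 224)).astype(np.float32)
    # Subtract the per-channel ImageNet mean, in B, G, R order
    im[:,:,0] -= 103.939
    im[:,:,1] -= 116.779
    im[:,:,2] -= 123.68
    # HWC -> CHW: channels-first (Theano-style) layout
    im = im.transpose((2,0,1))
    # Add a batch dimension: (1, 3, 224, 224)
    im = np.expand_dims(im, axis=0)
    return im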
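The preprocessing in `load_image` can be checked without OpenCV or the weights file. Below is a minimal sketch using a hypothetical `preprocess_bgr` helper (not part of the original script) that applies the same per-channel mean subtraction, channels-first reordering and batch axis to an image assumed to be already loaded as an HxWx3 BGR array, which is what `cv2.imread` plus `cv2.resize` would give. It assumes, as the `transpose((2,0,1))` in the original implies, that the model expects (channels, height, width) input.

```python
import numpy as np

# Per-channel ImageNet means in B, G, R order (the values used in load_image)
BGR_MEANS = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def preprocess_bgr(img_bgr):
    """Hypothetical helper: mean-subtract an HxWx3 BGR image and return a
    (1, 3, H, W) channels-first batch, mirroring load_image without cv2."""
    im = np.asarray(img_bgr, dtype=np.float32)
    if im.ndim != 3 or im.shape[2] != 3:
        raise ValueError("expected an HxWx3 BGR image")
    im = im - BGR_MEANS                 # broadcasts over the last (channel) axis
    im = im.transpose((2, 0, 1))        # HWC -> CHW
    return np.expand_dims(im, axis=0)   # add batch axis -> (1, 3, H, W)


# Usage: a dummy image whose every pixel equals the channel means
dummy = np.tile(BGR_MEANS, (224, 224, 1))   # shape (224, 224, 3)
batch = preprocess_bgr(dummy)
print(batch.shape)  # (1, 3, 224, 224)
```

Because every pixel of the dummy image equals the means, the result is all zeros, which makes it easy to see that channel 0 is blue and channel 2 is red after the transpose.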

Read the VGG16 class-label file:

with open('synset_words.txt', 'r') as f:
    lines = f.readlines()

def predict(url):
    im = load_image(url)
    # model is the loaded VGG_16 model; take the highest-scoring class index
    pre = np.argmax(model.predict(im))
    print(lines[pre].strip())
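`predict` reports only the single most likely line of `synset_words.txt`. A sketch of a top-k variant, assuming each line has the usual `<wnid> <comma-separated names>` layout (e.g. `n02123045 tabby, tabby cat`) and that `probs` is the score vector for one image, i.e. one row of `model.predict(im)`. `parse_synsets` and `top_k_labels` are hypothetical helper names, and the labels and scores in the usage are toy values, not real model output.

```python
import numpy as np


def parse_synsets(lines):
    """Split each 'wnid name1, name2' line into (wnid, human-readable names)."""
    parsed = []
    for line in lines:
        wnid, _, names = line.strip().partition(" ")
        parsed.append((wnid, names))
    return parsed


def top_k_labels(probs, synsets, k=5):
    """Return the k highest-scoring (wnid, names, score) tuples, best first."""
    probs = np.ravel(probs)                 # accept shape (1, N) or (N,)
    order = np.argsort(probs)[::-1][:k]     # indices by descending score
    return [(synsets[i][0], synsets[i][1], float(probs[i])) for i in order]


# Usage with a toy 3-class label list and score vector
lines = ["n01440764 tench, Tinca tinca\n",
         "n02123045 tabby, tabby cat\n",
         "n02124075 Egyptian cat\n"]
synsets = parse_synsets(lines)
scores = np.array([[0.1, 0.7, 0.2]])
for wnid, names, score in top_k_labels(scores, synsets, k=2):
    print(wnid, names, round(score, 2))
```

Parsing the file once up front, instead of indexing raw lines, also strips the trailing newline that `readlines()` keeps.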
# Jupyter notebook cells: enable inline plotting, then display the test image
%pylab inline
from IPython.display import Image
Image('cat1.jpg')

littlecangbaby · Mar 07 '17

import cv2
import numpy as np
from keras.optimizers import SGD

def load_image(url):
    im = cv2.resize(cv2.imread(url), (224, 224)).astype(np.float32)
    im[:,:,0] -= 103.939
    im[:,:,1] -= 116.779
    im[:,:,2] -= 123.68
    im = im.transpose((2,0,1))
    im = np.expand_dims(im, axis=0)
    return im

def predict(url):
    # Predict: read the class labels, then print the top-scoring one
    with open('synset_words.txt', 'r') as f:
        lines = f.readlines()
    im = load_image(url)
    pre = np.argmax(model.predict(im))
    print(lines[pre].strip())

if __name__ == "__main__":
    # Test pretrained model
    # VGG_16 is the model-building function from the tutorial (not defined here)
    model = VGG_16('vgg16_weights.h5')
    sgd = SGD(lr=1e-6, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss='categorical_crossentropy')
    predict('cat.1.jpg')
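`np.argmax(model.predict(im))` without an `axis` argument works in both scripts only because `im` holds a single image, so the flat index into the (1, 1000) output is the class index. For a batch of images the class has to be taken per row. A small illustration with made-up scores standing in for `model.predict` output (the real VGG16 output would have 1000 columns):

```python
import numpy as np

# Hypothetical stand-in for model.predict on a batch of 2 images, 4 classes
preds = np.array([[0.1, 0.6, 0.2, 0.1],
                  [0.5, 0.1, 0.1, 0.3]])

print(np.argmax(preds))          # 1: a single flat index over the whole array
print(np.argmax(preds, axis=1))  # [1 0]: one class index per image
```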

Ontheroad123 · Aug 16 '18