You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
6 years ago
|
|
||
|
# Part 3 - Making new predictions
|
||
|
import numpy as np
|
||
|
from PIL import Image, ImageDraw, ImageFont
|
||
|
from keras.preprocessing import image
|
||
|
from keras.models import model_from_yaml
|
||
|
from keras.preprocessing.image import ImageDataGenerator
|
||
|
import argparse
|
||
|
|
||
|
# Command-line interface: every positional argument is taken as the path of
# an image to classify (zero or more paths are accepted).
parser = argparse.ArgumentParser()
parser.add_argument(
    'images',
    nargs="*",
    help="images to classify",
)
args = parser.parse_args()
|
||
|
|
||
|
# load YAML and create model
# The architecture is read from 'model.yaml'; a context manager guarantees the
# file handle is closed even if reading fails.
# NOTE(review): model_from_yaml deserializes YAML into a model; later
# Keras/TensorFlow releases removed this API over arbitrary-code-execution
# concerns. Only load a model.yaml you trust, and confirm the installed Keras
# version still provides this function.
with open('model.yaml', 'r') as yaml_file:
    loaded_model_yaml = yaml_file.read()
classifier = model_from_yaml(loaded_model_yaml)
# load weights into new model
classifier.load_weights("model.h5")
print("Loaded model from disk")
|
||
|
|
||
|
|
||
|
# Classify every image given on the command line, print the raw model output
# and the chosen label, then save a copy of the image with the label drawn in
# a white strip appended below it ("<input path>.predicted.png").
for f in args.images:
    # Load the file at the 64x64 size requested from load_img, convert it to
    # an array and add a leading axis so predict() receives a batch of one.
    test_image = image.load_img(f, target_size=(64, 64))
    test_image = image.img_to_array(test_image)
    test_image = np.expand_dims(test_image, axis=0)
    result = classifier.predict(test_image)
    print(result)

    # WHAT ARE YOUR CLASSES?
    # Threshold the score instead of testing `== 1`: an exact float comparison
    # labels every output that is not precisely 1.0 as 'circle'.
    # NOTE(review): assumes the model ends in a single sigmoid unit where a
    # high score means 'rect' -- confirm against the training script.
    if result[0][0] >= 0.5:
        prediction = 'rect'
    else:
        prediction = 'circle'

    print("PREDICTION: {}".format(prediction))

    # WRITE RESULT TO IMAGE
    # The PIL image is bound to `source_image` (not `image`) so the imported
    # keras `image` module is not shadowed for the next iteration; the `with`
    # block closes the underlying file once it has been pasted.
    with Image.open(f) as source_image:
        width, height = source_image.size
        size = (width, height + 100)  # 100 extra pixels below for the label
        layer = Image.new('RGB', size, (255, 255, 255))
        layer.paste(source_image, (0, 0))

    draw = ImageDraw.Draw(layer)
    font = ImageFont.truetype('Roboto-Regular.ttf', size=45)
    (x, y) = (50, height + 20)  # inside the white strip under the picture
    message = prediction
    color = 'rgb(0, 0, 0)'  # black color
    draw.text((x, y), message, fill=color, font=font)

    layer.save("{}.predicted.png".format(f))
|