Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# ---------------------------------------------------------------------------
# Data pipeline: stream RGB images from disk, scaled to [0, 1].
# ---------------------------------------------------------------------------
# Rescaling to [0, 1] matches the sigmoid output layer of the decoder below.
train_datagen = ImageDataGenerator(rescale=1. / 255)
# Validation images must get exactly the same preprocessing as training images.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# Keras' `target_size` is (height, width). class_mode=None yields image batches
# only (no labels); `fixed_generator` (defined elsewhere) presumably pairs each
# batch with itself as the reconstruction target -- TODO confirm.
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    color_mode='rgb',
    class_mode=None)
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    color_mode='rgb',
    class_mode=None)

# ---------------------------------------------------------------------------
# Convolutional autoencoder.
# ---------------------------------------------------------------------------
# Channels-last input: (batch, height, width, 3 RGB channels), matching the
# (height, width) target_size used by the generators above.
input_img = Input(batch_shape=(None, img_height, img_width, 3))

# Encoder: three conv + 2x2 max-pool stages, each halving height and width.
x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

# Decoder: mirrors the encoder, each 2x2 upsampling doubling height and width.
# The output size equals the input size only when img_height and img_width
# are multiples of 8 (three halvings, three doublings).
x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
# padding='same' keeps the spatial size. With the default 'valid' padding,
# this layer trims 2 pixels, so a 224x224 input would be reconstructed as
# 220x220 and no longer match the input the loss is computed against.
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
# Sigmoid keeps reconstructed pixels in [0, 1], like the rescaled inputs.
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
# NOTE(review): in tf.keras the Adadelta default learning rate is 0.001
# (standalone Keras 2.x used 1.0), which trains far more slowly -- confirm
# which Keras this targets.
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
autoencoder.summary()
# ---------------------------------------------------------------------------
# Training: learn to reconstruct the (presumably normal) training images, then
# persist the learned weights.
# ---------------------------------------------------------------------------
# Floor division gives 0 when a split has fewer samples than batch_size, which
# would leave an epoch (or the validation pass) without any batches; always
# run at least one step.
train_steps = max(1, nb_train_samples // batch_size)
val_steps = max(1, nb_validation_samples // batch_size)

# NOTE(review): fit_generator is deprecated in tf.keras (model.fit accepts
# generators directly); kept because the targeted Keras version is unknown.
autoencoder.fit_generator(
    fixed_generator(train_generator),
    steps_per_epoch=train_steps,
    epochs=nb_epoch,
    validation_data=fixed_generator(validation_generator),
    validation_steps=val_steps)
autoencoder.save_weights('anomaly-detection.h5')
# ---------------------------------------------------------------------------
# Inference: reconstruct one image and flag it as anomalous when the
# reconstruction error exceeds an empirically chosen threshold.
# ---------------------------------------------------------------------------
im = cv2.imread(filePath)
if im is None:
    # cv2.imread returns None instead of raising for missing/unreadable files.
    raise OSError('Could not read image: {}'.format(filePath))
# OpenCV decodes to BGR; the training generators use color_mode='rgb'.
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# cv2.resize takes dsize as (width, height) -- a 2-tuple, not a shape.
im = cv2.resize(im, (img_width, img_height))

# Same [0, 1] scaling as training, plus a leading batch axis for predict().
batch = (im * 1. / 255)[None, ...]
dec = autoencoder.predict(batch)  # Decoded image batch
dec = dec[0]

# Compare in the 0-255 pixel scale the threshold was tuned on. The uint8
# input pixels are used directly (no lossy float round-trip); the
# reconstruction is quantized to uint8 as before.
test_image = im
dec = (dec * 255).astype('uint8')
# Cast to float so an `mse` helper (defined elsewhere) that subtracts its
# arguments directly cannot wrap around in uint8 arithmetic.
mse_value = mse(dec.astype('float64'), test_image.astype('float64'))
if mse_value > 2.74:  # Experimental threshold based on few iterations
    print('Image has an anomaly inside')
Add Comment
Please, Sign In to add comment