Advertisement
Guest User

Untitled

a guest
Sep 17th, 2019
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.18 KB | None | 0 0
  1. from tensorflow.keras import models
  2. from tensorflow.keras import layers
  3. from tensorflow.keras import optimizers
  4.  
  5. from tensorflow.keras.preprocessing import image
  6. from tensorflow.keras.preprocessing.image import ImageDataGenerator
  7.  
  8. from tensorflow.keras.applications import VGG16
  9. from tensorflow.keras.applications import vgg16
  10.  
  11. from tensorflow.keras.callbacks import ModelCheckpoint
  12.  
  13. import os
  14.  
  15.  
  16. # Create Keras model
  17. image_size = 150
  18. input_layer = layers.Input(shape=(image_size, image_size, 3), name="model_input")
  19. base_model = VGG16(weights="imagenet", include_top=False, input_tensor=input_layer)
  20. model_head = base_model.output
  21. model_head = layers.Flatten(name="model_head_flatten")(model_head)
  22. model_head = layers.Dense(256, activation="relu")(model_head)
  23. model_head = layers.Dense(2, activation="softmax")(model_head)
  24. model = models.Model(inputs=input_layer, outputs=model_head)
  25.  
  26. # Create image date generators
  27. # You need one image data folder with three sub-folders "train", "validation", "test"
  28. image_dir = "/home/mfb/Development/tf-github/data"
  29. datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)
  30. training_img_generator = datagen.flow_from_directory(os.path.join(image_dir, 'train'),
  31. target_size=(image_size, image_size), batch_size=20, class_mode="categorical")
  32. validation_img_generator = datagen.flow_from_directory(os.path.join(image_dir, 'validation'),
  33. target_size=(image_size, image_size), batch_size=20, class_mode="categorical")
  34. test_img_generator = datagen.flow_from_directory(os.path.join(image_dir, 'test'),
  35. target_size=(image_size, image_size), batch_size=20, class_mode="categorical")
  36.  
  37. # Compile Keras model
  38. model.compile(loss="categorical_crossentropy", optimizer=optimizers.Adam(), metrics=["accuracy"])
  39.  
  40. # Train Keras model
  41. auto_save_path = "/home/mfb/Development/tf-github/models"
  42. checkpoint = ModelCheckpoint(auto_save_path, monitor="val_acc", verbose=0, save_best_only=True)
  43. model.fit_generator(training_img_generator,
  44. steps_per_epoch=50, epochs=25, validation_steps=50,
  45. validation_data=validation_img_generator,
  46. callbacks=[checkpoint], verbose=1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement