Advertisement
Guest User

Untitled

a guest
Oct 15th, 2019
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.88 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
  5. from tensorflow.keras import Model
  6.  
  7.  
  8. class myCallback(tf.keras.callbacks.Callback):
  9. def on_epoch_end(self, epoch, logs={}):
  10. if logs.get('accuracy') > 0.90:
  11. print("\n90% accuracy reached and stopping training for now")
  12. self.model.stop_training = True
  13.  
  14.  
  15. class_names = ['negative', 'benign_calcification', 'benign_mass', 'malignant_calcification', 'malignant_mass']
  16.  
  17. train_path_files = ['training10_0/training10_0.tfrecords',
  18. 'training10_1/training10_1.tfrecords',
  19. 'training10_2/training10_2.tfrecords']
  20.  
  21. val_path_file = ['training10_3/training10_3.tfrecords']
  22.  
  23. test_path_file = ['training10_4/training10_4.tfrecords']
  24.  
  25. extracted_train_data = tf.data.TFRecordDataset(train_path_files)
  26. extracted_val_data = tf.data.TFRecordDataset(val_path_file)
  27. extracted_test_data = tf.data.TFRecordDataset(test_path_file)
  28.  
  29. feature_description = {
  30. 'label': tf.io.FixedLenFeature([], tf.int64, default_value=0),
  31. 'label_normal': tf.io.FixedLenFeature([], tf.int64, default_value=0),
  32. 'image': tf.io.FixedLenFeature([], tf.string, default_value='')
  33. }
  34.  
  35.  
  36. def decode(serialized_example):
  37. feature = tf.io.parse_single_example(serialized_example, feature_description)
  38.  
  39. # 2. Convert the data
  40. image = tf.io.decode_raw(feature['image'], tf.uint8)
  41. label = feature['label']
  42. # 3. reshape
  43. image = tf.reshape(image, [-1, 299, 299, 1])
  44. image = tf.cast(image, tf.float32)
  45. return image, label
  46.  
  47.  
  48. def _parse_function(example_proto):
  49. return tf.io.parse_single_example(example_proto, feature_description)
  50.  
  51.  
  52. # 44707 images total
  53. parsed_training_data = extracted_train_data.map(decode)
  54. parsed_val_data = extracted_val_data.map(decode)
  55. parsed_testing_data = extracted_test_data.map(decode)
  56.  
  57. #batch_size = 32
  58. #parsed_training_data = parsed_training_data.batch(batch_size).repeat()
  59. #parsed_val_data = parsed_val_data.batch(batch_size).repeat()
  60. #parsed_testing_data = parsed_testing_data.batch(batch_size).repeat()
  61.  
  62. callback = myCallback()
  63.  
  64. model = tf.keras.models.Sequential([
  65. tf.keras.layers.Conv2D(128, (3, 3), activation='relu', input_shape=(299, 299, 1)),
  66. tf.keras.layers.MaxPool2D(2, 2),
  67. tf.keras.layers.Flatten(),
  68. tf.keras.layers.Dense(128, activation='relu'),
  69. tf.keras.layers.Dropout(0.2),
  70. tf.keras.layers.Dense(5, activation='softmax')
  71. ])
  72.  
  73. model.compile(optimizer='adam',
  74. loss='sparse_categorical_crossentropy',
  75. metrics=['accuracy'])
  76.  
  77. # verbose is the progress bar when training
  78. history = model.fit(
  79. parsed_training_data,
  80. steps_per_epoch=10,
  81. shuffle=True,
  82. validation_data=parsed_val_data,
  83. validation_steps=2,
  84. epochs=5,
  85. verbose=2,
  86. callbacks=[callback]
  87. )
  88.  
  89. print('\nhistory dict:', history.history)
  90.  
  91. print('\n# Evaluate on test data')
  92. results = model.evaluate(parsed_testing_data, steps=1)
  93. print('test loss, test acc:', results)
  94.  
  95. for image, label in parsed_testing_data.take(5):
  96. predictions = model.predict(image.numpy())
  97. image = tf.reshape(image, [299,299])
  98. plt.imshow(image.numpy(), cmap=plt.cm.binary)
  99. plt.xlabel('True Value: %s,\n Predicted Values [%0.2f, %0.2f, %0.2f, %0.2f, %0.2f]' % (class_names[label.numpy()],
  100. predictions[0, 0],
  101. predictions[0, 1],
  102. predictions[0, 2],
  103. predictions[0, 3],
  104. predictions[0, 4]
  105. ))
  106. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement