Advertisement
NokitaKaze

Тестирую слои Keras

Nov 27th, 2021
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.68 KB | None | 0 0
  1. import tensorflow as tf
  2. from tensorflow.keras import Model, layers
  3.  
  4. img_w = 5
  5. img_h = 3
  6. img_c = 2
  7.  
  8.  
  9. class TestModel(Model):
  10.     # Set layers.
  11.     def __init__(self):
  12.         super(TestModel, self).__init__()
  13.  
  14.         self.single_file_pixels = img_w * img_h * img_c
  15.  
  16.         self.input_reshape = layers.Reshape((img_h, img_w, img_c))
  17.  
  18.         self.main_max_pool = layers.MaxPool2D(2, strides=2, padding='same')
  19.  
  20.         # Flatten the data to a 1-D vector for the fully connected layer.
  21.         self.flatten = layers.Flatten()
  22.  
  23.         # Fully connected layer.
  24.         self.dense_layer = layers.Dense(20)
  25.         # Apply Dropout (if is_training is False, dropout is not applied).
  26.         self.dropout = layers.Dropout(rate=0.2)
  27.  
  28.         # Output layer, class prediction.
  29.         self.dense_out = layers.Dense(2)
  30.  
  31.         self.conv2d = layers.Conv2D(2, kernel_size=2, activation=tf.nn.relu, padding='same')
  32.  
  33.     # Set forward pass.
  34.     def call(self, x, is_training=False, mask=None):
  35.         raw_items = self.input_reshape(x)
  36.  
  37.         conv1 = self.conv2d(raw_items)
  38.         conv2 = self.conv2d(conv1)
  39.         pool = self.main_max_pool(conv2)
  40.  
  41.         fc1 = self.flatten(pool)
  42.         fc2 = self.dense_layer(fc1)
  43.         do1 = self.dropout(fc2, training=is_training)
  44.         out = self.dense_out(do1)
  45.         y = out
  46.         if not is_training:
  47.             # tf cross entropy expect logits without softmax, so only
  48.             # apply softmax when not training.
  49.             y = tf.nn.softmax(y)
  50.  
  51.         return y
  52.  
  53.  
  54. with tf.device("/CPU"):
  55.     test_model = TestModel()
  56.     test_model.build(input_shape=(None, test_model.single_file_pixels))
  57.     test_model.summary()
  58.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement