Advertisement
Guest User

Untitled

a guest
Mar 20th, 2019
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.03 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. from tensorflow.keras.applications import MobileNetV2
  4.  
  5.  
  6. tf.enable_eager_execution()
  7.  
  8.  
  9. class ContextExtractor(tf.keras.Model):
  10.     def __init__(self, model_name, pretrained_shape):
  11.         super().__init__()
  12.         self.model = self.__get_model(model_name, pretrained_shape)
  13.  
  14.     def call(self, x, training=False, **kwargs):
  15.         print(training)
  16.         features = self.model(x, training=training)
  17.         return features
  18.  
  19.     def __get_model(self, model_name, pretrained_shape):
  20.         if model_name == "mobilenetv2":
  21.             return MobileNetV2(
  22.                 input_shape=pretrained_shape,
  23.                 weights="imagenet",
  24.                 alpha=0.35,
  25.                 include_top=False,
  26.                 pooling="avg",
  27.             )
  28.  
  29.  
  30. context_extractor = ContextExtractor("mobilenetv2", (224, 224, 3))
  31. bc = tf.random.uniform((10, 224, 224, 3))
  32.  
  33. for i in range(1, 10):
  34.     a1 = context_extractor(bc[:i], training=False)
  35.     print(np.linalg.norm(a1[0].numpy()))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement