Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import numpy as np
- from tensorflow.keras.applications import MobileNetV2
# Eager execution is required for the `.numpy()` call in the driver below.
# TF 1.x exposes `tf.enable_eager_execution` and needs it called before any
# ops are built. TF 2.x runs eagerly by default and no longer has this
# top-level name, so only call it when it exists.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()
class ContextExtractor(tf.keras.Model):
    """Wrap a pretrained Keras backbone as a feature ("context") extractor.

    The backbone is built once in ``__init__`` and every forward pass is
    delegated to it.

    Args:
        model_name: Identifier of the backbone to build. Only
            ``"mobilenetv2"`` is currently supported.
        pretrained_shape: Value passed as ``input_shape`` to the backbone
            constructor, e.g. ``(224, 224, 3)``.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """

    def __init__(self, model_name, pretrained_shape):
        super().__init__()
        self.model = self.__get_model(model_name, pretrained_shape)

    def call(self, x, training=False, **kwargs):
        """Run ``x`` through the backbone and return its output.

        Args:
            x: Input batch handed directly to the backbone.
            training: Forwarded unchanged to the backbone call.
            **kwargs: Accepted for Keras call compatibility; ignored here.

        Returns:
            Whatever the backbone returns for ``x``.
        """
        # Debug output: shows which ``training`` value reached this call.
        print(training)
        features = self.model(x, training=training)
        return features

    def __get_model(self, model_name, pretrained_shape):
        """Build and return the backbone named ``model_name``.

        Raises:
            ValueError: For an unsupported ``model_name``, so a typo fails
                at construction time instead of on the first forward pass.
        """
        if model_name == "mobilenetv2":
            # Per the Keras applications API, include_top=False with
            # pooling="avg" returns globally average-pooled features
            # (one vector per image) rather than ImageNet class scores.
            return MobileNetV2(
                input_shape=pretrained_shape,
                weights="imagenet",
                alpha=0.35,
                include_top=False,
                pooling="avg",
            )
        raise ValueError(
            "Unsupported model_name {!r}; expected 'mobilenetv2'.".format(
                model_name
            )
        )
# Build the extractor and one random batch of 10 images.
context_extractor = ContextExtractor("mobilenetv2", (224, 224, 3))
bc = tf.random.uniform((10, 224, 224, 3))

# Feed growing prefixes of the batch (1..9 rows) through the extractor in
# inference mode and print the L2 norm of the first row's output. Row 0 is
# the same image in every prefix — presumably the loop checks that its
# features do not change with batch size.
for batch_size in range(1, 10):
    batch = bc[:batch_size]
    features = context_extractor(batch, training=False)
    first_row = features[0].numpy()
    print(np.linalg.norm(first_row))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement