Advertisement
Guest User

Untitled

a guest
Oct 18th, 2017
65
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.49 KB | None | 0 0
  1.  
  2. class CNNTransformer(Transformer):
  3.     transformers = []
  4.  
  5.     CACHE_MAX_SIZE = 5
  6.     MAX_FRAME_COUNTER = 3
  7.  
  8.     def __init__(self):
  9.         super(CNNTransformer, self).__init__()
  10.         self.model = None
  11.         self.cache = []
  12.         self.frame_counter = 0
  13.  
  14.     def clear_cache(self):
  15.         self.cache.clear()
  16.         self.frame_counter = 0
  17.  
  18.     def load_model(self):
  19.         pass
  20.  
  21.  
  22. class CharPredictor(CNNTransformer):
  23.     __previous_predictions = to_categorical(np.random.randint(len(AsciiEncoder.AVAILABLE_CHARS), size=20),
  24.                                             len(AsciiEncoder.AVAILABLE_CHARS))
  25.  
  26.     def __init__(self, num_of_chars):
  27.         super(CharPredictor, self).__init__()
  28.         self.num_of_chars = num_of_chars
  29.         self.model = self.load_model()
  30.         self.out_key = 'chars'
  31.  
  32.     def transform(self, X, **transform_params):
  33.         if len(self.cache) < CNNTransformer.CACHE_MAX_SIZE:
  34.             if self.frame_counter == CNNTransformer.MAX_FRAME_COUNTER:
  35.                 self.cache.append(True)
  36.                 self.frame_counter = 0
  37.             else:
  38.                 self.frame_counter += 1
  39.             return None
  40.         x_data = CharPredictor.__previous_predictions
  41.         self.output = self.model.predict(np.asarray([x_data]))[0]
  42.         self.clear_cache()
  43.         return self.output
  44.  
  45.     def prepare_data(self):
  46.         predictions = CharPredictor.__previous_predictions[:]
  47.         if len(predictions) >= self.num_of_chars:
  48.             while len(predictions) > self.num_of_chars:
  49.                 predictions.pop(0)
  50.         else:
  51.             while len(predictions) != self.num_of_chars:
  52.                 predictions.insert(0, 0)
  53.         return to_categorical(predictions, len(AsciiEncoder.AVAILABLE_CHARS))
  54.  
  55.     @staticmethod
  56.     def add_to_previous_predictions(prediction):
  57.         CharPredictor.__previous_predictions = np.append(CharPredictor.__previous_predictions[1:],
  58.                                                          to_categorical(prediction, len(AsciiEncoder.AVAILABLE_CHARS)),
  59.                                                          axis=0)
  60.  
  61.     def load_model(self):
  62.         with open(os.path.join(CHAR_PREDICTION_FOLDER, ARCHITECTURE_JSON_NAME), 'r') as f:
  63.             model = model_from_json(f.read())
  64.         model.load_weights(os.path.join(CHAR_PREDICTION_FOLDER, WEIGHTS_HDF5_NAME))
  65.         return model
  66.         # return CharPredictionMock.model(time_steps=self.num_of_chars, feature_length=len(AsciiEncoder.AVAILABLE_CHARS))
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement