Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class CNNTransformer(Transformer):
    """Base class for CNN-backed transformers that carry frame-throttling state.

    Instances hold a ``model`` slot (``None`` until a subclass fills it), a
    ``cache`` list and a ``frame_counter`` integer. ``clear_cache`` resets
    both of the latter; how they gate prediction is up to the subclass.
    """

    # Class-level list shared by every instance and subclass (not used in this view).
    transformers = []
    # Upper bound on len(cache) that subclasses compare against.
    CACHE_MAX_SIZE = 5
    # Frame-counter value that subclasses compare against before caching.
    MAX_FRAME_COUNTER = 3

    def __init__(self):
        super().__init__()
        self.model, self.cache, self.frame_counter = None, [], 0

    def clear_cache(self):
        """Empty the cache in place and restart the frame counter from zero."""
        del self.cache[:]
        self.frame_counter = 0

    def load_model(self):
        """Model-loading hook for subclasses; the base implementation loads nothing."""
        return None
class CharPredictor(CNNTransformer):
    """Predict the next character from a rolling history of one-hot predictions.

    The history lives on the class, so every instance shares it. It is a
    2-D array with one one-hot row per character over
    ``AsciiEncoder.AVAILABLE_CHARS``, seeded with 20 random characters.
    ``transform`` runs the model only when the inherited cache/frame
    throttle allows it and returns ``None`` otherwise.
    """

    # Shared history of one-hot rows, seeded randomly so the model has input
    # from the very first prediction.
    __previous_predictions = to_categorical(
        np.random.randint(len(AsciiEncoder.AVAILABLE_CHARS), size=20),
        len(AsciiEncoder.AVAILABLE_CHARS))

    def __init__(self, num_of_chars):
        """Create a predictor whose model input spans ``num_of_chars`` history rows.

        Loads the model from disk immediately (see ``load_model``).
        """
        super(CharPredictor, self).__init__()
        self.num_of_chars = num_of_chars
        self.model = self.load_model()
        self.out_key = 'chars'

    def transform(self, X, **transform_params):
        """Return the model output for the current history, or ``None`` while throttled.

        ``X`` and ``transform_params`` are ignored; the model input comes from
        the shared prediction history, trimmed/padded by ``prepare_data``.
        """
        # NOTE(review): the pasted source lost its indentation. The trailing
        # ``return None`` is read here as belonging to the outer ``if``, so
        # that the cache fills to CACHE_MAX_SIZE before a prediction runs.
        # Read the other way, CACHE_MAX_SIZE would be dead code. Confirm
        # against the original file.
        if len(self.cache) < CNNTransformer.CACHE_MAX_SIZE:
            if self.frame_counter == CNNTransformer.MAX_FRAME_COUNTER:
                self.cache.append(True)
                self.frame_counter = 0
            else:
                self.frame_counter += 1
            return None
        # Use exactly num_of_chars rows; the raw history has a fixed 20 rows
        # regardless of the model's expected time steps.
        x_data = self.prepare_data()
        self.output = self.model.predict(np.asarray([x_data]))[0]
        self.clear_cache()
        return self.output

    def prepare_data(self):
        """Return the newest ``num_of_chars`` history rows as a new 2-D array.

        Returns an array of shape ``(num_of_chars, len(AsciiEncoder.AVAILABLE_CHARS))``.
        When the history is shorter than ``num_of_chars``, leading rows are
        filled with the one-hot encoding of class 0. This mirrors the previous
        intent of inserting index 0 before one-hot encoding.
        """
        num_classes = len(AsciiEncoder.AVAILABLE_CHARS)
        history = np.asarray(CharPredictor.__previous_predictions).reshape(-1, num_classes)
        if self.num_of_chars <= 0:
            # history[-0:] would return everything, so handle this explicitly.
            return history[:0].copy()
        # Copy so callers can never mutate the shared class-level history.
        window = history[-self.num_of_chars:].copy()
        missing = self.num_of_chars - len(window)
        if missing > 0:
            pad = np.zeros((missing, num_classes), dtype=window.dtype)
            pad[:, 0] = 1
            window = np.concatenate([pad, window], axis=0)
        return window

    @staticmethod
    def add_to_previous_predictions(prediction):
        """Append ``prediction`` to the shared history, dropping the oldest rows.

        ``prediction`` is a class index (or a sequence of indices) into
        ``AsciiEncoder.AVAILABLE_CHARS``. As many old rows are dropped as new
        ones are appended, so the window length stays constant.
        """
        num_classes = len(AsciiEncoder.AVAILABLE_CHARS)
        # Recent Keras returns a 1-D vector for a scalar label. np.append(...,
        # axis=0) rejects that against the 2-D history, so force 2-D rows.
        encoded = np.reshape(to_categorical(prediction, num_classes), (-1, num_classes))
        CharPredictor.__previous_predictions = np.append(
            CharPredictor.__previous_predictions[len(encoded):],
            encoded,
            axis=0)

    def load_model(self):
        """Build the model from the saved JSON architecture and load its HDF5 weights."""
        architecture_path = os.path.join(CHAR_PREDICTION_FOLDER, ARCHITECTURE_JSON_NAME)
        with open(architecture_path, 'r') as f:
            model = model_from_json(f.read())
        model.load_weights(os.path.join(CHAR_PREDICTION_FOLDER, WEIGHTS_HDF5_NAME))
        return model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement