Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class PredictSklearn(beam.DoFn):
    """Attach a model prediction to every element handled by this DoFn.

    The model is unpickled in ``setup`` from the local file that
    ``download_blob`` is asked to write, and ``process`` stores the result
    of ``model.predict(element["data"])`` under ``element["prediction"]``.
    """

    def __init__(self, project=None, bucket_name=None, model_path=None,
                 destination_name=None):
        """Remember where the model lives; nothing is downloaded here.

        Args:
            project: Forwarded to ``download_blob`` as ``project``.
            bucket_name: Forwarded to ``download_blob`` as ``bucket_name``.
            model_path: Forwarded to ``download_blob`` as
                ``source_blob_name``; also written to the log in ``setup``.
            destination_name: Forwarded to ``download_blob`` as
                ``destination_file_name`` and then opened to unpickle the
                model.
        """
        # Let the base DoFn initialise its own state before adding ours.
        super(PredictSklearn, self).__init__()
        self._model = None
        self._project = project
        self._bucket_name = bucket_name
        self._model_path = model_path
        self._destination_name = destination_name

    # Load once or very few times
    def setup(self):
        """Fetch the pickled model and load it into ``self._model``."""
        logging.info("Sklearn model initialisation %s", self._model_path)
        # presumably copies the blob to ``destination_name`` on local disk;
        # verify against the download_blob helper.
        download_blob(
            bucket_name=self._bucket_name,
            source_blob_name=self._model_path,
            project=self._project,
            destination_file_name=self._destination_name,
        )
        # SECURITY: unpickling executes arbitrary code contained in the file.
        # Only point this at model artifacts from a trusted location.
        # The context manager closes the handle even if unpickling fails.
        with open(self._destination_name, 'rb') as model_file:
            self._model = pickle.load(model_file)

    def process(self, element):
        """Compute the prediction for one element.

        Args:
            element: Object supporting item access with a ``"data"`` entry
                that is passed to the model's ``predict`` method.

        Returns:
            A one-item list containing the same ``element`` object, now
            carrying a ``"prediction"`` entry.
        """
        # NOTE(review): this writes into the input element in place. Beam's
        # guidance is that DoFns should not modify their inputs; confirm
        # whether a copy should be emitted instead.
        element["prediction"] = self._model.predict(element["data"])
        return [element]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement