Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from __future__ import print_function
- import os
- import io
- import boto3
- import base64
- import numpy as np
- import zipfile
- from PIL import Image
- from tflite_runtime.interpreter import Interpreter
# Runtime configuration read from the function's environment variables:
#   bkt -> name of the S3 bucket to read model archives from
#   dir -> key prefix inside that bucket that is searched for archives
bkt, direc = os.environ['bkt'], os.environ['dir']

# S3 handles created once, at module import time.
s3 = boto3.resource('s3')
bucket = s3.Bucket(bkt)
def get_last_modified(obj):
    """
    Return ``obj.last_modified`` as whole seconds since the Unix epoch.

    ``obj`` is anything with a ``last_modified`` datetime (e.g. an S3 object
    summary). ``datetime.timestamp()`` honours the datetime's tzinfo and is
    portable, whereas ``strftime('%s')`` is a glibc-only extension that
    ignores tzinfo and interprets the value as local time.
    """
    return int(obj.last_modified.timestamp())
def get_new_model(prefix='mdls', s3_bucket=None, dest='/tmp/'):
    """
    Download the newest zip archive under ``prefix`` and extract it.

    Every object listed under ``prefix`` is a candidate except keys ending
    in ``/`` (folder placeholders, which have no file name to download to).
    The candidate with the latest ``last_modified`` is downloaded into
    ``dest`` and extracted there.

    :param prefix: S3 key prefix to search.
    :param s3_bucket: bucket to read from; defaults to the module-level
        ``bucket``.
    :param dest: local directory the archive is downloaded to and
        extracted into.
    :returns: list of paths of the files extracted from the archive
        (directory entries are omitted).
    :raises ValueError: if no candidate object exists under ``prefix``.
    """
    if s3_bucket is None:
        s3_bucket = bucket
    candidates = [obj for obj in s3_bucket.objects.filter(Prefix=prefix)
                  if not obj.key.endswith('/')]
    if not candidates:
        raise ValueError(
            'no model archive found under prefix {!r}'.format(prefix))
    # max() is O(n) and, on ties, keeps the first-listed object.
    newest = max(candidates, key=lambda obj: obj.last_modified)
    archive_path = os.path.join(dest, newest.key.split('/')[-1])
    s3_bucket.download_file(newest.key, archive_path)
    # Return only what this archive contained rather than everything in
    # dest; extract() also returns the sanitized path it actually wrote.
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        extracted = [zip_ref.extract(name, dest) for name in zip_ref.namelist()]
    return [path for path in extracted if os.path.isfile(path)]
def _first_with_suffix(paths, suffix):
    """
    Return the first path in ``paths`` that ends with ``suffix``.

    :raises ValueError: naming the missing suffix when nothing matches.
    """
    for path in paths:
        if path.endswith(suffix):
            return path
    raise ValueError('no {} file among {!r}'.format(suffix, paths))


# Download the newest model archive from S3 and pick out the label file and
# the TFLite model it provided.
_model_files = get_new_model(prefix=direc)
label_path = _first_with_suffix(_model_files, '.txt')
mdl_path = _first_with_suffix(_model_files, '.tflite')

# Load labels: one per line; line number i names output class i.
with io.open(label_path, 'r', encoding='utf-8') as fl:
    labels = [line.rstrip('\n') for line in fl]
label_dict = dict(enumerate(labels))

# Load the model once at import time and keep its tensor metadata.
interpreter = Interpreter(model_path=mdl_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
def _prepare_input(payload, detail):
    """
    Decode one image and convert it into the model's input tensor.

    :param payload: raw image bytes in any format Pillow can open.
    :param detail: an entry from ``interpreter.get_input_details()``. Its
        ``shape`` is assumed to be NHWC ``[1, height, width, 3]`` -- TODO
        confirm against the deployed model.
    :returns: array of shape ``(1, height, width, 3)`` with dtype
        ``detail['dtype']``.
    """
    height, width = detail['shape'][1], detail['shape'][2]
    # Converting to RGB gives exactly three channels for grayscale, palette
    # and RGBA sources alike; slicing [:, :, :3] fails on 2-D arrays.
    image = Image.open(io.BytesIO(payload)).convert('RGB')
    # Pillow sizes are (width, height); the tensor shape is (height, width).
    image = image.resize((width, height))
    pixels = np.asarray(image)
    dtype = detail['dtype']
    if np.issubdtype(dtype, np.floating):
        # Float models: scale pixel values from 0..255 into 0..1.
        pixels = pixels.astype(dtype) / 255
    else:
        # Integer (quantized) models take raw pixel values; dividing would
        # produce float64, which set_tensor rejects for an integer input.
        pixels = pixels.astype(dtype)
    return np.expand_dims(pixels, axis=0)


def lambda_handler(event, context):
    """
    Classify each image in a Kinesis batch and print its predicted label.

    :param event: Kinesis trigger event; each
        ``event['Records'][i]['kinesis']['data']`` holds a base64-encoded
        image.
    :param context: Lambda context object (unused).
    """
    detail = input_details[0]
    for record in event['Records']:
        # Kinesis data is base64 encoded, so decode it first.
        payload = base64.b64decode(record['kinesis']['data'])
        interpreter.set_tensor(detail['index'], _prepare_input(payload, detail))
        interpreter.invoke()
        result = interpreter.get_tensor(output_details[0]['index'])
        # argmax over the flattened output; assumes a single row of class
        # scores (batch size 1).
        print("Prediction: " + str(label_dict[int(np.argmax(result))]))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement