SHARE
TWEET

Untitled

a guest Oct 19th, 2019 84 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from __future__ import print_function
  2. import os
  3. import io
  4. import boto3
  5. import base64
  6. import numpy as np
  7. import zipfile
  8. from PIL import Image
  9. from tflite_runtime.interpreter import Interpreter
  10.  
  11. bkt = os.environ['bkt']
  12. direc = os.environ['dir']
  13.  
  14. s3 = boto3.resource('s3')
  15. bucket = s3.Bucket(bkt)
  16.  
  17. get_last_modified = lambda obj: int(obj.last_modified.strftime('%s'))
  18.  
  19. def get_new_model(prefix='mdls'):
  20.     """
  21.     Returns a list of file paths for model and label file
  22.     """
  23.     unsorted = [file for file in bucket.objects.filter(Prefix=prefix)]
  24.     mdl = sorted(unsorted, key=get_last_modified, reverse=True)[0]
  25.     mdl_path = '/tmp/' + mdl.key.split('/')[-1]
  26.     bucket.download_file(mdl.key, mdl_path)
  27.     with zipfile.ZipFile(mdl_path,"r") as zip_ref:
  28.         zip_ref.extractall('/tmp/')
  29.     paths = os.listdir('/tmp/')
  30.     return ['/tmp/'+ x for x in paths]
  31.  
  32. # Download model + labels from S3
  33. mdl_path = get_new_model(prefix=direc)
  34. label_path = [x for x in mdl_path if x.endswith('txt')][0]
  35. mdl_path = [x for x in mdl_path if x.endswith('tflite')][0]
  36.  
  37. # Load labels
  38. with open(label_path, "r") as fl:
  39.     labels = fl.readlines()
  40. labels = [x.replace('\n', '') for x in labels]
  41. label_dict = dict(zip(range(0,len(labels)), labels))
  42.    
  43. # Load model
  44. interpreter = Interpreter(model_path=mdl_path)
  45. interpreter.allocate_tensors()
  46.  
  47. input_details = interpreter.get_input_details()
  48. output_details = interpreter.get_output_details()
  49.  
  50. def lambda_handler(event, context):
  51.     for record in event['Records']:
  52.         #Kinesis data is base64 encoded so decode here
  53.         payload = base64.b64decode(record['kinesis']['data'])
  54.         image = Image.open(io.BytesIO(payload))
  55.         image = image.resize((input_details[0]['shape'][1],input_details[0]['shape'][2]))
  56.         data = np.expand_dims(np.asarray(image).astype(input_details[0]['dtype'])[:, :, :3], axis=0) / 255
  57.         interpreter.set_tensor(input_details[0]['index'], data)
  58.         interpreter.invoke()
  59.         result = interpreter.get_tensor(output_details[0]['index'])
  60.         print("Prediction: " + str(label_dict[np.argmax(result)]))
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top