Advertisement
Guest User

Untitled

a guest
Oct 19th, 2019
116
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.04 KB | None | 0 0
  1. from __future__ import print_function
  2. import os
  3. import io
  4. import boto3
  5. import base64
  6. import numpy as np
  7. import zipfile
  8. from PIL import Image
  9. from tflite_runtime.interpreter import Interpreter
  10.  
  11. bkt = os.environ['bkt']
  12. direc = os.environ['dir']
  13.  
  14. s3 = boto3.resource('s3')
  15. bucket = s3.Bucket(bkt)
  16.  
  17. get_last_modified = lambda obj: int(obj.last_modified.strftime('%s'))
  18.  
  19. def get_new_model(prefix='mdls'):
  20. """
  21. Returns a list of file paths for model and label file
  22. """
  23. unsorted = [file for file in bucket.objects.filter(Prefix=prefix)]
  24. mdl = sorted(unsorted, key=get_last_modified, reverse=True)[0]
  25. mdl_path = '/tmp/' + mdl.key.split('/')[-1]
  26. bucket.download_file(mdl.key, mdl_path)
  27. with zipfile.ZipFile(mdl_path,"r") as zip_ref:
  28. zip_ref.extractall('/tmp/')
  29. paths = os.listdir('/tmp/')
  30. return ['/tmp/'+ x for x in paths]
  31.  
  32. # Download model + labels from S3
  33. mdl_path = get_new_model(prefix=direc)
  34. label_path = [x for x in mdl_path if x.endswith('txt')][0]
  35. mdl_path = [x for x in mdl_path if x.endswith('tflite')][0]
  36.  
  37. # Load labels
  38. with open(label_path, "r") as fl:
  39. labels = fl.readlines()
  40. labels = [x.replace('\n', '') for x in labels]
  41. label_dict = dict(zip(range(0,len(labels)), labels))
  42.  
  43. # Load model
  44. interpreter = Interpreter(model_path=mdl_path)
  45. interpreter.allocate_tensors()
  46.  
  47. input_details = interpreter.get_input_details()
  48. output_details = interpreter.get_output_details()
  49.  
  50. def lambda_handler(event, context):
  51. for record in event['Records']:
  52. #Kinesis data is base64 encoded so decode here
  53. payload = base64.b64decode(record['kinesis']['data'])
  54. image = Image.open(io.BytesIO(payload))
  55. image = image.resize((input_details[0]['shape'][1],input_details[0]['shape'][2]))
  56. data = np.expand_dims(np.asarray(image).astype(input_details[0]['dtype'])[:, :, :3], axis=0) / 255
  57. interpreter.set_tensor(input_details[0]['index'], data)
  58. interpreter.invoke()
  59. result = interpreter.get_tensor(output_details[0]['index'])
  60. print("Prediction: " + str(label_dict[np.argmax(result)]))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement