Advertisement
jsprieto10

Untitled

Sep 24th, 2019
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.98 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """HackathonMicrosoft.ipynb
  3.  
  4. Automatically generated by Colaboratory.
  5.  
  6. Original file is located at
  7.    https://colab.research.google.com/drive/18TYVSG2W1l3Y5jB4g67hupLsvsyrD6jf
  8. """
  9.  
  10. import numpy as np
  11. import torch
  12. from torch import nn
  13. import torch.nn.functional as F
  14. from torchvision import datasets, transforms, models,utils
  15. from flask import Flask,request,Response,jsonify
  16. import json
  17. from flask_cors import CORS
  18. from PIL import Image
  19. import io
  20.  
  21. UPLOAD_FOLDER = '/path/to/the/uploads'
  22. ALLOWED_EXTENSIONS = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'])
  23. app = Flask(__name__)
  24. CORS(app)
  25.  
  26.  
  27.  
  28. model = models.densenet161(pretrained=True)
  29.  
  30. for name in model.children():
  31.   for child, config in name.named_children():
  32.     print(str(child) + ' is frozen')
  33.     for param in config.parameters():
  34.           param.requires_grad = False
  35.  
  36. from collections import OrderedDict
  37. classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2208, 512)),
  38.                                         ('relu1', nn.ReLU()),
  39.                                         ('dropout1', nn.Dropout(0.3)),
  40.                                          ('fc3', nn.Linear(512, 2)),
  41.                                         ('output', nn.LogSoftmax(dim=1))]))
  42.  
  43. model.classifier=classifier
  44.  
  45.  
  46.  
  47. #model.load_state_dict(torch.load('my_file.pt', map_location='cpu'))
  48.  
  49.  
  50. model=model.cpu()
  51.  
  52.  
  53. def transform_image(image_bytes):
  54.     my_transforms = transforms.Compose([transforms.Resize(255),
  55.                                         transforms.CenterCrop(224),
  56.                                         transforms.ToTensor(),
  57.                                         transforms.Normalize(
  58.                                             [0.485, 0.456, 0.406],
  59.                                             [0.229, 0.224, 0.225])])
  60.     image = Image.open(io.BytesIO(image_bytes))
  61.     return my_transforms(image).unsqueeze(0)
  62.  
  63.  
  64.  
  65. def get_prediction(image, topk=2):
  66.     ''' Predict the class (or classes) of an image using a trained deep learning model.
  67.    '''
  68.     global model
  69.     with torch.no_grad():
  70.         output = model.forward(image)
  71.         probabililty = torch.exp(output)
  72.        
  73.    
  74.     prob,cla = probabililty.topk(topk)
  75.     print(cla, 'clase')
  76.    
  77.     return prob.cpu().numpy().squeeze(), cla.cpu().numpy().squeeze()
  78.    
  79.  
  80. @app.route('/predict', methods=['POST'])
  81. def pred():
  82.     if request.method == 'POST':
  83.         # we will get the file from the request
  84.         file = request.files['file']
  85.         # convert that to bytes
  86.         img_bytes = file.read()
  87.         tensor = transform_image(image_bytes=img_bytes)
  88.         pro, cla = get_prediction(tensor)
  89.         print(pro,cla)
  90.        
  91.         l=[]
  92.         for p, c in zip(pro,cla):
  93.             d={}
  94.             m=["bueno","malo"]
  95.             d['clase']= m[c]
  96.             d['probabilidad']=str(p)
  97.             l.append(d)
  98.         return jsonify(l)
  99.  
  100.  
  101.  
  102.  
  103. if __name__ == "__main__":
  104.     app.run(debug=True, port=8000, host='0.0.0.0')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement