Advertisement
jsprieto10

Untitled

Sep 24th, 2019
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.77 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """HackathonMicrosoft.ipynb
  3.  
  4. Automatically generated by Colaboratory.
  5.  
  6. Original file is located at
  7.    https://colab.research.google.com/drive/18TYVSG2W1l3Y5jB4g67hupLsvsyrD6jf
  8. """
  9.  
  10. import numpy as np
  11. import torch
  12. from torch import nn
  13. import torch.nn.functional as F
  14. from torchvision import datasets, transforms, models,utils
  15. from flask import Flask,request,Response,jsonify
  16. import json
  17. from flask_cors import CORS
  18. from PIL import Image
  19.  
  20. UPLOAD_FOLDER = '/path/to/the/uploads'
  21. ALLOWED_EXTENSIONS = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'])
  22. app = Flask(__name__)
  23. CORS(app)
  24.  
  25.  
  26.  
  27. model = models.densenet161(pretrained=True)
  28.  
  29. for name in model.children():
  30.   for child, config in name.named_children():
  31.     print(str(child) + ' is frozen')
  32.     for param in config.parameters():
  33.           param.requires_grad = False
  34.  
  35. from collections import OrderedDict
  36. classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2208, 512)),
  37.                                         ('relu1', nn.ReLU()),
  38.                                         ('dropout1', nn.Dropout(0.3)),
  39.                                          ('fc3', nn.Linear(512, 6)),
  40.                                         ('output', nn.LogSoftmax(dim=1))]))
  41.  
  42. model.classifier=classifier
  43.  
  44.  
  45.  
  46. model.load_state_dict(torch.load('my_file.pt', map_location='cpu'))
  47.  
  48.  
  49. model=model.cpu()
  50.  
  51.  
  52. def transform_image(image_bytes):
  53.     my_transforms = transforms.Compose([transforms.Resize(255),
  54.                                         transforms.CenterCrop(224),
  55.                                         transforms.ToTensor(),
  56.                                         transforms.Normalize(
  57.                                             [0.485, 0.456, 0.406],
  58.                                             [0.229, 0.224, 0.225])])
  59.     image = Image.open(io.BytesIO(image_bytes))
  60.     return my_transforms(image).unsqueeze(0)
  61.  
  62.  
  63.  
  64. def get_predict(image, topk=2):
  65.     ''' Predict the class (or classes) of an image using a trained deep learning model.
  66.    '''
  67.     global model
  68.     with torch.no_grad():
  69.         output = model.forward(image)
  70.         probabililty = torch.exp(output)
  71.        
  72.    
  73.     prob,cla = probabililty.topk(topk)
  74.     print(cla, 'clase')
  75.    
  76.     return prob.cpu().numpy().squeeze(), cla.cpu().numpy().squeeze()
  77.    
  78.  
  79. @app.route('/predict', methods=['POST'])
  80. def pred():
  81.     if request.method == 'POST':
  82.         # we will get the file from the request
  83.         file = request.files['file']
  84.         # convert that to bytes
  85.         img_bytes = file.read()
  86.         tensor = transform_image(image_bytes=image_bytes)
  87.         class_id, class_name = get_prediction(image)
  88.         return jsonify({'gg':1})
  89.  
  90.  
  91.  
  92.  
  93. if __name__ == "__main__":
  94.     app.run(debug=True, port=8000, host='0.0.0.0')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement