Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # -*- coding: utf-8 -*-
- """HackathonMicrosoft.ipynb
- Automatically generated by Colaboratory.
- Original file is located at
- https://colab.research.google.com/drive/18TYVSG2W1l3Y5jB4g67hupLsvsyrD6jf
- """
- import numpy as np
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torchvision import datasets, transforms, models,utils
- from flask import Flask,request,Response,jsonify
- import json
- from flask_cors import CORS
- from PIL import Image
# NOTE(review): placeholder path; nothing in this file appears to read it.
UPLOAD_FOLDER = '/path/to/the/uploads'
# File extensions the service is meant to accept for upload.
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}

app = Flask(__name__)
# Enable cross-origin requests (e.g. from a browser front end) on every route.
CORS(app)
# Build the DenseNet-161 architecture. No pretrained download is needed:
# load_state_dict() below is strict, so every weight comes from the checkpoint.
model = models.densenet161()

from collections import OrderedDict

# Replace the stock head with the 6-class head the checkpoint was trained with
# (2208 = DenseNet-161 feature width). LogSoftmax means the model outputs
# log-probabilities; get_predict exponentiates them.
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2208, 512)),
                                        ('relu1', nn.ReLU()),
                                        ('dropout1', nn.Dropout(0.3)),
                                        ('fc3', nn.Linear(512, 6)),
                                        ('output', nn.LogSoftmax(dim=1))]))
model.classifier = classifier

# NOTE(review): torch.load unpickles the file -- only ever load a trusted checkpoint.
model.load_state_dict(torch.load('my_file.pt', map_location='cpu'))
model = model.cpu()

# Inference only: no gradients are needed for any weight, and Dropout/BatchNorm
# must run in eval mode or predictions are random and BatchNorm running
# statistics drift with every request.
for param in model.parameters():
    param.requires_grad = False
model.eval()
# Preprocessing pipeline, built once at import time rather than per request:
# resize the shorter side to 255, centre-crop to 224x224, convert to a float
# tensor, and normalise with the standard ImageNet per-channel mean/std.
_PREPROCESS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode raw image bytes into a single-image model input batch.

    Args:
        image_bytes: the encoded image file contents (PNG, JPEG, GIF, ...).

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        OSError: (PIL's UnidentifiedImageError) if the bytes are not a
            readable image.
    """
    import io  # not imported at file level; needed for BytesIO

    # Force 3 channels: RGBA PNGs, palette GIFs and greyscale images would
    # otherwise produce tensors the 3-channel Normalize rejects.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return _PREPROCESS(image).unsqueeze(0)
def get_predict(image, topk=2, net=None):
    """Predict the top-k classes of a preprocessed image batch.

    Args:
        image: input tensor passed straight to the network; presumably the
            (1, 3, 224, 224) batch from ``transform_image`` -- confirm.
        topk: number of highest-probability classes to return.
        net: model to run; defaults to the module-level ``model``. Its output
            is treated as log-probabilities (it is exponentiated below).

    Returns:
        ``(probabilities, class_indices)`` as NumPy arrays, highest first,
        with size-1 dimensions squeezed out -- for one image and topk > 1
        these are two 1-D arrays of length ``topk``.
    """
    if net is None:
        net = model
    # Dropout/BatchNorm must be in inference mode, otherwise the output is
    # random and BatchNorm running statistics are updated by every call.
    net.eval()
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        log_probs = net(image)
        probabilities = torch.exp(log_probs)
        prob, cla = probabilities.topk(topk)
    return prob.cpu().numpy().squeeze(), cla.cpu().numpy().squeeze()
@app.route('/predict', methods=['POST'])
def pred():
    """Classify an uploaded image and return its top-k predictions as JSON.

    Expects a multipart/form-data POST with the image in the ``file`` field.

    Returns:
        JSON ``{'gg': 1, 'probabilities': [...], 'classes': [...]}`` holding
        the top-k probabilities and their class indices (highest first), or
        a 400 JSON error if the upload is missing or not a readable image.
    """
    # The route only accepts POST, so no request.method check is needed.
    upload = request.files.get('file')
    if upload is None:
        return jsonify({'error': "missing 'file' field"}), 400
    img_bytes = upload.read()
    try:
        tensor = transform_image(image_bytes=img_bytes)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-images.
        return jsonify({'error': 'could not decode image'}), 400
    probs, classes = get_predict(tensor)
    # get_predict squeezes its output (0-d arrays when topk == 1), and NumPy
    # arrays are not JSON-serializable: normalise to plain Python lists.
    return jsonify({
        'gg': 1,  # key from the original placeholder response, kept for compatibility
        'probabilities': np.atleast_1d(probs).tolist(),
        'classes': np.atleast_1d(classes).tolist(),
    })
if __name__ == "__main__":
    # Debug mode is deliberately not hard-coded: the Werkzeug debugger lets
    # anyone who can reach the server run arbitrary Python, and this binds to
    # all interfaces. app.run() still honours FLASK_DEBUG=1 for local work.
    app.run(port=8000, host='0.0.0.0')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement