Advertisement
daniilak

Untitled

May 7th, 2021
181
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.93 KB | None | 0 0
  1. import imghdr
  2. import os
  3. import glob
  4. from flask import Flask, render_template, request, redirect, url_for, abort, send_from_directory, jsonify
  5. from werkzeug.utils import secure_filename
  6. from utils.utils import calc_mean_score, save_json
  7. from handlers.model_builder import Nima
  8. from handlers.data_generator import TestDataGenerator
  9.  
  10. app = Flask(__name__)
  11.  
  12. def predict(model, data_generator):
  13.     return model.predict_generator(data_generator, workers=8, use_multiprocessing=True, verbose=1)
  14.  
  15. def image_file_to_json(img_path):
  16.     img_dir = os.path.dirname(img_path)
  17.     img_id = os.path.basename(img_path).split('.')[0]
  18.     return img_dir, [{'image_id': img_id}]
  19.  
  20. def image_dir_to_json(img_dir, img_type='jpg'):
  21.     img_paths = glob.glob(os.path.join(img_dir, '*.'+img_type))
  22.     samples = []
  23.     for img_path in img_paths:
  24.         img_id = os.path.basename(img_path).split('.')[0]
  25.         samples.append({'image_id': img_id})
  26.  
  27.     return samples
  28.  
  29. nima1 = Nima("MobileNet", weights=None)
  30. nima1.build()
  31. nima1.nima_model.load_weights("MobileNet/weights_mobilenet_aesthetic_0.07.hdf5")
  32.  
  33. nima2 = Nima("MobileNet", weights=None)
  34. nima2.build()
  35. nima2.nima_model.load_weights("MobileNet/weights_mobilenet_technical_0.11.hdf5")
  36.  
  37. image_dir = 'test_images'
  38.  
  39. app.config['MAX_CONTENT_LENGTH'] = 1024 * 1024
  40. app.config['UPLOAD_EXTENSIONS'] = ['.jpg']
  41. app.config['UPLOAD_PATH'] = 'uploads'
  42.  
  43. def validate_image(stream):
  44.     header = stream.read(512)  # 512 bytes should be enough for a header check
  45.     stream.seek(0)  # reset stream pointer
  46.     format = imghdr.what(None, header)
  47.     if not format:
  48.         return None
  49.     return '.' + (format if format != 'jpeg' else 'jpg')
  50.  
  51. @app.route('/c/', methods=['GET'])
  52. def index():
  53.     files = os.listdir(app.config['UPLOAD_PATH'])
  54.     return render_template('index.html', files=files)
  55.  
  56. @app.route('/c/test/<filename>', methods=['GET'])
  57. def test(filename):
  58.     filename = str(filename)
  59.     if len(filename) == 0:
  60.         abort(400)
  61.     image_dir_full = '/'+image_dir+'/'+filename
  62.     samples1 = image_dir_to_json(image_dir_full, img_type='jpg')
  63.     samples2 = samples1
  64.     # return jsonify(samples)
  65.     data_generator1 = TestDataGenerator(samples1, image_dir_full, 64, 10, nima1.preprocessing_function(),img_format='jpg')
  66.     data_generator2 = TestDataGenerator(samples2, image_dir_full, 64, 10, nima2.preprocessing_function(),img_format='jpg')
  67.     predictions1 = predict(nima1.nima_model, data_generator1)
  68.     predictions2 = predict(nima2.nima_model, data_generator2)
  69.  
  70.     for i, sample in enumerate(samples1):
  71.         sample['msp'] = calc_mean_score(predictions1[i])
  72.     for i, sample in enumerate(samples2):
  73.         sample['msp'] = calc_mean_score(predictions2[i])
  74.     del data_generator1
  75.     del data_generator2
  76.     del predictions1
  77.     del predictions2
  78.     return jsonify([samples1, samples2])
  79.  
  80. if __name__ == '__main__':
  81.    app.run(host='0.0.0.0', port=5060, threaded=False, processes=1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement