Advertisement
Guest User

Untitled

a guest
Feb 6th, 2018
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.43 KB | None | 0 0
  1. import logging
  2. import os
  3. import sys
  4. import time
  5. from optparse import OptionParser
  6.  
  7. import psycopg2
  8. from psycopg2.extras import DictCursor
  9.  
  10. sys.path.append(os.getcwd())
  11.  
  12. import fastText as ft
  13. import pandas as pd
  14. from pymongo import MongoClient
  15. import yaml
  16. from inference.inference import HierarchicalModel
  17. from inference.models import ProductMeta
  18. from dataset_tools.push_labels_into_file import load_label_into_idx
  19.  
  20. FORMAT = '%(asctime)-15s %(message)s'
  21. logging.basicConfig(format=FORMAT, level=logging.INFO)
  22. LOGGER = logging.getLogger(__name__)
  23.  
  24. """
  25.    Inference model
  26.    Fetch all textual information of product includes title, description, product tags, product type, options of product
  27. """
  28. if __name__ == '__main__':
  29.  
  30.     parser = OptionParser()
  31.     parser.add_option("-c", "--config_path", dest="config_path",
  32.                       help="Config path")
  33.     parser.add_option("-f", "--file_path", dest="file_path",
  34.                       help="file path for inference")
  35.  
  36.     (options, args) = parser.parse_args()
  37.     config_path = options.config_path
  38.     file_path = options.file_path
  39.  
  40.     FILL_EMPTY = " "
  41.  
  42.     config = yaml.load(open(config_path))
  43.  
  44.     hierarchical_model = HierarchicalModel(model_lv1_path=config['model_lv1_path'],
  45.                                            model_lv2_path=config['model_lv2_path'],
  46.                                            model_load_fn=ft.load_model,
  47.                                            label_lv1_path=config['label_lv1_path'],
  48.                                            label_lv2_path=config['label_lv2_path'],
  49.                                            prefix=config['prefix'])
  50.  
  51.     POSTGRESQL_HOST = config['POSTGRESQL']['HOST']
  52.     POSTGRESQL_PORT = config['POSTGRESQL']['PORT']
  53.     POSTGRESQL_DBNAME = config['POSTGRESQL']['DBNAME']
  54.     POSTGRESQL_USERNAME = config['POSTGRESQL']['USERNAME']
  55.     POSTGRESQL_PASSWORD = config['POSTGRESQL']['PASSWORD']
  56.  
  57.     # Connect to postgresql
  58.  
  59.     # - *dbname*: the database name
  60.     # - *database*: the database name (only as keyword argument)
  61.     # - *user*: user name used to authenticate
  62.     # - *password*: password used to authenticate
  63.     # - *host*: database host address (defaults to UNIX socket if not provided)
  64.     # - *port*: connection port number (defaults to 5432 if not provided)
  65.  
  66.  
  67.     conn = psycopg2.connect(database=POSTGRESQL_DBNAME,
  68.                             user=POSTGRESQL_USERNAME,
  69.                             password=POSTGRESQL_PASSWORD,
  70.                             port=POSTGRESQL_PORT)
  71.  
  72.  
  73.     LOGGER.info("Connect to postgresql")
  74.     cur = conn.cursor(cursor_factory=DictCursor)
  75.  
  76.  
  77.     LOGGER.info("Loading model")
  78.     hierarchical_model.load()
  79.  
  80.     labels_mapper = load_label_into_idx(config['label_lv2_path'])
  81.  
  82.     # Production mongodb
  83.     client_production = MongoClient(host=config['MONGODB_PRODUCTION']['HOST'],
  84.                                     port=config['MONGODB_PRODUCTION']['PORT'])
  85.     db_product = client_production['beeketing']['Product']
  86.  
  87.     labels_dat = []
  88.     LOGGER.info("Begin to process")
  89.  
  90.     items = cur.execute("SELECT shop_id, product_id from reporting.top_10_products")
  91.     for idx, item in enumerate(items.fetch_all()):
  92.  
  93.         shop_id = int(float(item['shop_id']))
  94.         product_id = item['product_id']
  95.  
  96.         stime = time.time()
  97.         try:
  98.             product = db_product.find_one({"_id": product_id})
  99.             product_meta = ProductMeta(title=product.get("title", FILL_EMPTY),
  100.                                        description=product.get("shortDescription", FILL_EMPTY),
  101.                                        tags=product.get("tags", FILL_EMPTY),
  102.                                        options=product.get("options", FILL_EMPTY))
  103.  
  104.             label, prob = hierarchical_model.predict(product_meta, threshold=0.6, k_top_2=1)[0]
  105.             # Write to file
  106.             labels_dat.append([shop_id, product_id, labels_mapper.get(label), prob])
  107.             etime = time.time()
  108.             LOGGER.info("Finish to inference: %s" % (etime - stime))
  109.         except Exception as e:
  110.             LOGGER.error("Error %s on shop %s and %s", e, shop_id, product_id)
  111.             with open("data/shop_inference_error.txt", "a") as f:
  112.                 f.write("%s|%s\n" % (shop_id, product_id))
  113.  
  114.     dat = pd.DataFrame(labels_dat, columns=['shop_id', 'product_id', 'label', 'prob'])
  115.     dat.to_csv("data/beeketing_top_10.csv", sep="|", index=None)
  116.  
  117.     LOGGER.info("Finished")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement