Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import logging
- import os
- import sys
- import time
- from optparse import OptionParser
- import psycopg2
- from psycopg2.extras import DictCursor
- sys.path.append(os.getcwd())
- import fastText as ft
- import pandas as pd
- from pymongo import MongoClient
- import yaml
- from inference.inference import HierarchicalModel
- from inference.models import ProductMeta
- from dataset_tools.push_labels_into_file import load_label_into_idx
# Module-wide logging: timestamped messages at INFO level and above.
FORMAT = '%(asctime)-15s %(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Inference script.
# Fetches the textual information of each product (title, description,
# product tags and product options) and runs the hierarchical model on it.
# NOTE(review): the original note also listed "product type", but the script
# body never reads such a field -- confirm whether it should be included.
def main():
    """Run hierarchical label inference over ``reporting.top_10_products``.

    Steps:
      1. Parse ``-c/--config_path`` (required) and ``-f/--file_path``
         (accepted for backward compatibility; currently unused).
      2. Load the YAML config and build the ``HierarchicalModel``.
      3. Read ``shop_id``/``product_id`` pairs from PostgreSQL.
      4. For each pair, fetch the product document from the MongoDB
         ``beeketing`` database, ``Product`` collection, build a
         ``ProductMeta`` from its title, short description, tags and
         options, and predict a label.
      5. Write the predictions to ``data/beeketing_top_10.csv``
         (``|``-separated). Products that fail are appended to
         ``data/shop_inference_error.txt`` as ``shop_id|product_id``.
    """
    parser = OptionParser()
    parser.add_option("-c", "--config_path", dest="config_path",
                      help="Config path")
    parser.add_option("-f", "--file_path", dest="file_path",
                      help="file path for inference")
    options, _ = parser.parse_args()
    if not options.config_path:
        # Fail with a usage message instead of a TypeError from open(None).
        parser.error("--config_path is required")

    fill_empty = " "

    # safe_load: the config is plain data, so the arbitrary-object loader is
    # not needed, and yaml.load() without a Loader fails on PyYAML >= 6.
    # The context manager closes the file handle the original leaked.
    with open(options.config_path) as config_file:
        config = yaml.safe_load(config_file)

    hierarchical_model = HierarchicalModel(
        model_lv1_path=config['model_lv1_path'],
        model_lv2_path=config['model_lv2_path'],
        model_load_fn=ft.load_model,
        label_lv1_path=config['label_lv1_path'],
        label_lv2_path=config['label_lv2_path'],
        prefix=config['prefix'])

    pg_config = config['POSTGRESQL']
    # The configured host was previously read but never passed, so the
    # connection silently fell back to the local UNIX socket.
    conn = psycopg2.connect(database=pg_config['DBNAME'],
                            user=pg_config['USERNAME'],
                            password=pg_config['PASSWORD'],
                            host=pg_config['HOST'],
                            port=pg_config['PORT'])
    LOGGER.info("Connect to postgresql")

    client_production = None
    labels_dat = []
    try:
        cur = conn.cursor(cursor_factory=DictCursor)

        LOGGER.info("Loading model")
        hierarchical_model.load()
        labels_mapper = load_label_into_idx(config['label_lv2_path'])

        # Production mongodb
        mongo_config = config['MONGODB_PRODUCTION']
        client_production = MongoClient(host=mongo_config['HOST'],
                                        port=mongo_config['PORT'])
        db_product = client_production['beeketing']['Product']

        LOGGER.info("Begin to process")
        # psycopg2's execute() returns None; rows are read from the cursor.
        cur.execute("SELECT shop_id, product_id from reporting.top_10_products")
        for item in cur.fetchall():
            shop_id = int(float(item['shop_id']))
            product_id = item['product_id']
            stime = time.time()
            try:
                product = db_product.find_one({"_id": product_id})
                if product is None:
                    # Explicit error instead of an opaque
                    # "'NoneType' object has no attribute 'get'".
                    raise LookupError("product not found in MongoDB")
                product_meta = ProductMeta(
                    title=product.get("title", fill_empty),
                    description=product.get("shortDescription", fill_empty),
                    tags=product.get("tags", fill_empty),
                    options=product.get("options", fill_empty))
                label, prob = hierarchical_model.predict(
                    product_meta, threshold=0.6, k_top_2=1)[0]
                labels_dat.append(
                    [shop_id, product_id, labels_mapper.get(label), prob])
                LOGGER.info("Finish to inference: %s", time.time() - stime)
            except Exception as e:
                # Best effort: one bad product must not abort the whole run.
                # Record it so it can be retried later.
                LOGGER.error("Error %s on shop %s and %s", e, shop_id, product_id)
                with open("data/shop_inference_error.txt", "a") as f:
                    f.write("%s|%s\n" % (shop_id, product_id))
    finally:
        # Release both database connections even if processing fails.
        if client_production is not None:
            client_production.close()
        conn.close()

    dat = pd.DataFrame(labels_dat,
                       columns=['shop_id', 'product_id', 'label', 'prob'])
    dat.to_csv("data/beeketing_top_10.csv", sep="|", index=False)
    LOGGER.info("Finished")


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement