package edu.wiki.search;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;

import edu.wiki.api.concept.IConceptIterator;
import edu.wiki.api.concept.IConceptVector;
import edu.wiki.api.concept.scorer.CosineScorer;
import edu.wiki.concept.ConceptVectorSimilarity;
import edu.wiki.concept.TroveConceptVector;
import edu.wiki.index.WikipediaAnalyzer;
import edu.wiki.util.HeapSort;

import gnu.trove.TIntFloatHashMap;
import gnu.trove.TIntIntHashMap;

/**
 * Performs search on the index located in the database.
 *
 * @author Cagatay Calli <ccalli@gmail.com>
 */
public class ESASearcher {
    Connection connection;

    PreparedStatement pstmtQuery;
    PreparedStatement pstmtIdfQuery;
    PreparedStatement pstmtLinks;
    Statement stmtInlink;

    WikipediaAnalyzer analyzer;

    // term -> postings (concept vector) and term -> IDF lookups
    String strTermQuery = "SELECT t.vector FROM idx t WHERE t.term = ?";
    String strIdfQuery = "SELECT t.idf FROM terms t WHERE t.term = ?";

    String strMaxConcept = "SELECT MAX(id) FROM article";

    String strInlinks = "SELECT i.target_id, i.inlink FROM inlinks i WHERE i.target_id IN ";

    String strLinks = "SELECT target_id FROM pagelinks WHERE source_id = ?";

    int maxConceptId;

    // reusable concept id / score arrays, sized to the largest concept id
    int[] ids;
    double[] values;

    // per-query term statistics
    HashMap<String, Integer> freqMap = new HashMap<String, Integer>(30);
    HashMap<String, Double> tfidfMap = new HashMap<String, Double>(30);
    HashMap<String, Float> idfMap = new HashMap<String, Float>(30);

    ArrayList<String> termList = new ArrayList<String>(30);

    TIntIntHashMap inlinkMap;

    static float LINK_ALPHA = 0.5f;

    ConceptVectorSimilarity sim = new ConceptVectorSimilarity(new CosineScorer());

    public void initDB() throws ClassNotFoundException, SQLException, IOException {
        // Load the JDBC driver
        String driverName = "com.mysql.jdbc.Driver"; // MySQL Connector
        Class.forName(driverName);

        // read DB config (host, database, user, password -- one per line)
        InputStream is = ESASearcher.class.getResourceAsStream("/config/db.conf");
        BufferedReader br = new BufferedReader(new InputStreamReader(is));
        String serverName = br.readLine();
        String mydatabase = br.readLine();
        String username = br.readLine();
        String password = br.readLine();
        br.close();

        // Create a connection to the database
        String url = "jdbc:mysql://" + serverName + "/" + mydatabase; // a JDBC url
        connection = DriverManager.getConnection(url, username, password);

        pstmtQuery = connection.prepareStatement(strTermQuery);
        pstmtQuery.setFetchSize(1);

        pstmtIdfQuery = connection.prepareStatement(strIdfQuery);
        pstmtIdfQuery.setFetchSize(1);

        pstmtLinks = connection.prepareStatement(strLinks);
        pstmtLinks.setFetchSize(500);

        stmtInlink = connection.createStatement();
        stmtInlink.setFetchSize(50);

        // array size = largest article (concept) id + 1
        ResultSet res = connection.createStatement().executeQuery(strMaxConcept);
        res.next();
        maxConceptId = res.getInt(1) + 1;
    }

    private void clean() {
        freqMap.clear();
        tfidfMap.clear();
        idfMap.clear();
        termList.clear();
        inlinkMap.clear();

        Arrays.fill(ids, 0);
        Arrays.fill(values, 0);
    }

    public ESASearcher() throws ClassNotFoundException, SQLException, IOException {
        initDB();
        analyzer = new WikipediaAnalyzer();

        ids = new int[maxConceptId];
        values = new double[maxConceptId];

        inlinkMap = new TIntIntHashMap(300);
    }

    @Override
    protected void finalize() throws Throwable {
        connection.close();
        super.finalize();
    }

    /**
     * Retrieves the full concept vector for regular features.
     * @param query
     * @return Returns the concept vector if results exist, otherwise null
     * @throws IOException
     * @throws SQLException
     */
    public IConceptVector getConceptVector(String query) throws IOException, SQLException {
        String strTerm;
        int numTerms = 0;
        ResultSet rs;
        int doc;
        double score;
        int vint;
        double vdouble;
        double tf;
        double vsum;
        int plen;
        TokenStream ts = analyzer.tokenStream("contents", new StringReader(query));

        this.clean();

        for (int i = 0; i < ids.length; i++) {
            ids[i] = i;
        }

        ts.reset();

        while (ts.incrementToken()) {
            TermAttribute t = ts.getAttribute(TermAttribute.class);
            strTerm = t.term();

            // record term IDF
            if (!idfMap.containsKey(strTerm)) {
                pstmtIdfQuery.setBytes(1, strTerm.getBytes("UTF-8"));
                pstmtIdfQuery.execute();

                rs = pstmtIdfQuery.getResultSet();
                if (rs.next()) {
                    idfMap.put(strTerm, rs.getFloat(1));
                }
            }

            // record term counts for TF
            if (freqMap.containsKey(strTerm)) {
                vint = freqMap.get(strTerm);
                freqMap.put(strTerm, vint + 1);
            } else {
                freqMap.put(strTerm, 1);
            }

            termList.add(strTerm);

            numTerms++;
        }

        ts.end();
        ts.close();

        if (numTerms == 0) {
            return null;
        }

        // calculate TF-IDF vector (normalized)
        vsum = 0;
        for (String tk : idfMap.keySet()) {
            tf = 1.0 + Math.log(freqMap.get(tk));
            vdouble = (idfMap.get(tk) * tf);
            tfidfMap.put(tk, vdouble);
            vsum += vdouble * vdouble;
        }
        vsum = Math.sqrt(vsum);

        // comment this out for canceling query normalization
        for (String tk : idfMap.keySet()) {
            vdouble = tfidfMap.get(tk);
            tfidfMap.put(tk, vdouble / vsum);
        }

        score = 0;
        for (String tk : termList) {
            pstmtQuery.setBytes(1, tk.getBytes("UTF-8"));
            pstmtQuery.execute();

            rs = pstmtQuery.getResultSet();

            if (rs.next()) {
                final ByteArrayInputStream bais = new ByteArrayInputStream(rs.getBytes(1));
                final DataInputStream dis = new DataInputStream(bais);

                /**
                 * Posting format:
                 * 4 bytes: int - length of array,
                 * followed by (4-byte int doc, 4-byte float tfidf) pairs
                 */
                plen = dis.readInt();
                for (int k = 0; k < plen; k++) {
                    doc = dis.readInt();
                    score = dis.readFloat();
                    values[doc] += score * tfidfMap.get(tk);
                }

                bais.close();
                dis.close();
            }
        }

        // no result (no posting was read for any query term)
        if (score == 0) {
            return null;
        }

        HeapSort.heapSort(values, ids);

        IConceptVector newCv = new TroveConceptVector(ids.length);
        for (int i = ids.length - 1; i >= 0 && values[i] > 0; i--) {
            newCv.set(ids[i], values[i] / numTerms);
        }

        return newCv;
    }

    /**
     * Returns a trimmed form of the concept vector, keeping at most LIMIT entries.
     * @param cv
     * @param LIMIT
     * @return
     */
    public IConceptVector getNormalVector(IConceptVector cv, int LIMIT) {
        IConceptIterator it;

        if (cv == null)
            return null;

        IConceptVector cv_normal = new TroveConceptVector(LIMIT);
        it = cv.orderedIterator();

        int count = 0;
        while (it.next()) {
            if (count >= LIMIT) break;
            cv_normal.set(it.getId(), it.getValue());
            count++;
        }

        return cv_normal;
    }

    private TIntIntHashMap setInlinkCounts(Collection<Integer> ids) throws SQLException {
        inlinkMap.clear();

        if (ids.isEmpty()) {
            return inlinkMap;   // avoid building an invalid "IN ()" clause
        }

        // build the "(id1,id2,...)" part of the IN clause
        StringBuilder inPart = new StringBuilder("(");
        for (int id : ids) {
            inPart.append(id).append(",");
        }
        inPart.setCharAt(inPart.length() - 1, ')');

        // collect inlink counts
        ResultSet r = stmtInlink.executeQuery(strInlinks + inPart);
        while (r.next()) {
            inlinkMap.put(r.getInt(1), r.getInt(2));
        }

        return inlinkMap;
    }

    private Collection<Integer> getLinks(int id) throws SQLException {
        ArrayList<Integer> links = new ArrayList<Integer>(100);

        pstmtLinks.setInt(1, id);

        ResultSet r = pstmtLinks.executeQuery();
        while (r.next()) {
            links.add(r.getInt(1));
        }

        return links;
    }

    public IConceptVector getLinkVector(IConceptVector cv, int limit) throws SQLException {
        if (cv == null)
            return null;
        return getLinkVector(cv, true, LINK_ALPHA, limit);
    }

    /**
     * Computes the secondary interpretation vector of regular features.
     * @param cv
     * @param moreGeneral
     * @param ALPHA
     * @param LIMIT
     * @return
     * @throws SQLException
     */
    public IConceptVector getLinkVector(IConceptVector cv, boolean moreGeneral, double ALPHA, int LIMIT) throws SQLException {
        IConceptIterator it;

        if (cv == null)
            return null;

        it = cv.orderedIterator();

        int count = 0;
        ArrayList<Integer> pages = new ArrayList<Integer>();

        TIntFloatHashMap valueMap2 = new TIntFloatHashMap(1000);
        TIntFloatHashMap valueMap3 = new TIntFloatHashMap();

        ArrayList<Integer> npages = new ArrayList<Integer>();

        HashMap<Integer, Float> secondMap = new HashMap<Integer, Float>(1000);

        this.clean();

        // collect article objects
        while (it.next()) {
            pages.add(it.getId());
            valueMap2.put(it.getId(), (float) it.getValue());
            count++;
        }

        // prepare inlink counts
        setInlinkCounts(pages);

        for (int pid : pages) {
            Collection<Integer> raw_links = getLinks(pid);
            if (raw_links.isEmpty()) {
                continue;
            }
            ArrayList<Integer> links = new ArrayList<Integer>(raw_links.size());

            final double inlink_factor_p = Math.log(inlinkMap.get(pid));

            float origValue = valueMap2.get(pid);

            setInlinkCounts(raw_links);

            for (int lid : raw_links) {
                final double inlink_factor_link = Math.log(inlinkMap.get(lid));

                // check concept generality: keep only links that are
                // sufficiently more general than the source page
                if (inlink_factor_link - inlink_factor_p > 1) {
                    links.add(lid);
                }
            }

            for (int lid : links) {
                if (!valueMap2.containsKey(lid)) {
                    valueMap2.put(lid, 0.0f);
                    npages.add(lid);
                }
            }

            float linkedValue = 0.0f;

            // propagate the source page's value to its accepted links
            for (int lid : links) {
                if (valueMap3.containsKey(lid)) {
                    linkedValue = valueMap3.get(lid);
                    linkedValue += origValue;
                    valueMap3.put(lid, linkedValue);
                } else {
                    valueMap3.put(lid, origValue);
                }
            }
        }

        // for(int pid : pages){
        //     if(valueMap3.containsKey(pid)){
        //         secondMap.put(pid, (float) (valueMap2.get(pid) + ALPHA * valueMap3.get(pid)));
        //     }
        //     else {
        //         secondMap.put(pid, (float) (valueMap2.get(pid) ));
        //     }
        // }

        for (int pid : npages) {
            secondMap.put(pid, (float) (ALPHA * valueMap3.get(pid)));
        }

        // System.out.println("read links..");

        ArrayList<Integer> keys = new ArrayList<Integer>(secondMap.keySet());

        // Sort keys by values, highest first.
        final Map<Integer, Float> langForComp = secondMap;
        Collections.sort(keys, new Comparator<Integer>() {
            public int compare(Integer leftKey, Integer rightKey) {
                Float leftValue = langForComp.get(leftKey);
                Float rightValue = langForComp.get(rightKey);
                return leftValue.compareTo(rightValue);
            }
        });
        Collections.reverse(keys);

        IConceptVector cv_link = new TroveConceptVector(maxConceptId);

        int c = 0;
        for (int p : keys) {
            cv_link.set(p, secondMap.get(p));
            c++;
            if (c >= LIMIT) {
                break;
            }
        }

        return cv_link;
    }

    public IConceptVector getCombinedVector(String query) throws IOException, SQLException {
        IConceptVector cvBase = getConceptVector(query);
        IConceptVector cvNormal, cvLink;

        if (cvBase == null) {
            return null;
        }

        cvNormal = getNormalVector(cvBase, 10);
        cvLink = getLinkVector(cvNormal, 5);

        cvNormal.add(cvLink);

        return cvNormal;
    }

    /**
     * Calculates semantic relatedness between two documents.
     * @param doc1
     * @param doc2
     * @return returns relatedness if successful, 0 otherwise
     */
    public double getRelatedness(String doc1, String doc2) {
        try {
            // IConceptVector c1 = getCombinedVector(doc1);
            // IConceptVector c2 = getCombinedVector(doc2);
            // IConceptVector c1 = getNormalVector(getConceptVector(doc1), 10);
            // IConceptVector c2 = getNormalVector(getConceptVector(doc2), 10);

            IConceptVector c1 = getConceptVector(doc1);
            IConceptVector c2 = getConceptVector(doc2);

            if (c1 == null || c2 == null) {
                return 0;
            }

            return sim.calcSimilarity(c1, c2);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }
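
    /**
     * Minimal usage sketch: assumes the ESA index tables and the
     * /config/db.conf file read by initDB() are already in place; the two
     * example strings below are arbitrary input documents.
     */
    public static void main(String[] args) throws Exception {
        ESASearcher searcher = new ESASearcher();

        // Full concept vector for a free-text query, trimmed to the top 10 concepts
        IConceptVector cv = searcher.getNormalVector(
                searcher.getConceptVector("semantic relatedness of words"), 10);
        if (cv != null) {
            IConceptIterator it = cv.orderedIterator();
            while (it.next()) {
                System.out.println(it.getId() + "\t" + it.getValue());
            }
        }

        // Cosine similarity between the ESA vectors of two short documents
        double score = searcher.getRelatedness("computer programming", "software engineering");
        System.out.println("relatedness = " + score);
    }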
}