{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Place the `data.tar.gz` archive into the `data` directory and unpack it, then:\n",
"\n",
"* rename the file with the correct answers to `test.qrel_clean`;\n",
"* create the predictions file named `train.qrel_clean`;\n",
"* update the file names in `eval.py`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data/\n",
"data/eval.py\n",
"data/cran.qry\n",
"data/qrel_clean\n",
"data/cran.all.1400\n"
]
}
],
"source": [
"!tar -xvzf data/data.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import abc\n",
"import math\n",
"import string\n",
"import functools\n",
"import subprocess\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"\n",
"from itertools import chain, product\n",
"from operator import itemgetter\n",
"\n",
"import nltk\n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"from nltk.stem import PorterStemmer\n",
"from nltk.corpus import stopwords\n",
"\n",
"\n",
"from typing import (\n",
"    Optional,\n",
"    Generator,\n",
"    List\n",
")\n",
"\n",
"from tqdm import tqdm\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"lemmatizer = WordNetLemmatizer()\n",
"stemmer = PorterStemmer()\n",
"\n",
"DIRECTORY = \"data\"\n",
"TEXTS_FILE = os.path.join(DIRECTORY, \"cran.all.1400\")\n",
"QUERIES_FILE = os.path.join(DIRECTORY, \"cran.qry\")\n",
"PREDICTION_FILE = os.path.join(DIRECTORY, \"train.qrel_clean\")\n",
"CORRECT_ANSWERS_FILE = os.path.join(DIRECTORY, \"test.qrel_clean\")\n",
"NUMBER_TEXTS = 1400\n",
"NUMBER_QUERIES = 225"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Parsing"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Text:\n",
"    __slots__ = [\"i\", \"t\", \"a\", \"b\", \"w\"]\n",
"\n",
"    def __init__(self):\n",
"        self.i = None  # type: str\n",
"        self.t = None  # type: str\n",
"        self.a = None  # type: str\n",
"        self.b = None  # type: str\n",
"        self.w = None  # type: str\n",
"\n",
"\n",
"class Query:\n",
"    __slots__ = [\"i\", \"w\"]\n",
"\n",
"    def __init__(self):\n",
"        self.i = None  # type: str\n",
"        self.w = None  # type: str\n",
"\n",
"\n",
"def _read_file(filepath):\n",
"    # type: (str) -> Generator[str]\n",
"\n",
"    with open(filepath) as file:\n",
"        yield from file\n",
"\n",
"\n",
"def _parse(gen, cls, line_starts):\n",
"    # type: (Generator[str], type, list) -> Generator\n",
"\n",
"    def set_current_state(line):\n",
"        nonlocal current_state\n",
"\n",
"        for i, (s, _) in enumerate(line_starts):\n",
"            if line.startswith(s):\n",
"                current_state = i\n",
"\n",
"    def yield_text():\n",
"        t = cls()\n",
"        for i, (_, s) in enumerate(line_starts):\n",
"            setattr(t, s, ''.join(text_lists[i]))\n",
"\n",
"        return t\n",
"\n",
"    text_lists = [[] for _ in range(len(line_starts))]\n",
"    current_state = -1\n",
"\n",
"    for line in chain(gen, ['.I']):\n",
"        set_current_state(line)\n",
"\n",
"        # a new '.I' section closes the previous record\n",
"        if current_state == 0:\n",
"            if any(text_lists):\n",
"                yield yield_text()\n",
"            text_lists = [[] for _ in range(len(line_starts))]\n",
"\n",
"        text_lists[current_state].append(line)\n",
"\n",
"\n",
"def get_texts_gen(filepath):\n",
"    # type: (str) -> Generator[Text]\n",
"\n",
"    gen = _read_file(filepath)\n",
"    texts = _parse(gen, Text, [('.I', 'i'), ('.T', 't'), ('.A', 'a'), ('.B', 'b'), ('.W', 'w')])\n",
"    return texts\n",
"\n",
"\n",
"def get_queries_gen(filepath):\n",
"    # type: (str) -> Generator[Query]\n",
"\n",
"    gen = _read_file(filepath)\n",
"    queries = _parse(gen, Query, [('.I', 'i'), ('.W', 'w')])\n",
"    return queries"
]
},
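{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the parser, here it is on a tiny hand-written sample in the Cranfield `.I/.T/.A/.B/.W` markup (the sample text below is made up for illustration, not taken from the dataset):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Hand-written two-record sample; only the markup matters here.\n",
"sample = [\n",
"    '.I 1\\n', '.T\\n', 'sample title one\\n', '.A\\n', 'author one\\n',\n",
"    '.B\\n', 'source one\\n', '.W\\n', 'first sample body .\\n',\n",
"    '.I 2\\n', '.T\\n', 'sample title two\\n', '.A\\n', 'author two\\n',\n",
"    '.B\\n', 'source two\\n', '.W\\n', 'second sample body .\\n',\n",
"]\n",
"\n",
"for text in _parse(iter(sample), Text, [('.I', 'i'), ('.T', 't'), ('.A', 'a'), ('.B', 'b'), ('.W', 'w')]):\n",
"    print(repr(text.i.strip()), '->', repr(text.w.strip()))"
]
},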
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Normalization\n",
"\n",
"We use the `nltk` library.\n",
"\n",
"* Remove all digits and punctuation characters from the text.\n",
"* Split it into tokens; drop stop words and single-character tokens.\n",
"* Apply stemming and lemmatization.\n",
"\n",
"A small sanity check of this pipeline follows the implementation below."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ItemTokens:\n",
"    __slots__ = [\"i\", \"tokens\"]\n",
"\n",
"    def __init__(self, text, attr):\n",
"        self.i = text.i\n",
"\n",
"        assert hasattr(text, attr)\n",
"\n",
"        self.tokens = self._filter(self._clean(getattr(text, attr)))\n",
"\n",
"    def _filter(self, s):\n",
"        # drop the section markers, stop words and one-character tokens,\n",
"        # then lemmatize and stem whatever is left\n",
"        stop_tokens = [\"I\", \"T\", \"A\", \"B\", \"W\"]\n",
"        return (stemmer.stem(lemmatizer.lemmatize(t)) for t in word_tokenize(s)\n",
"                if (t not in stop_words) and (t not in stop_tokens) and (len(t) >= 2))\n",
"\n",
"    def _clean(self, s):\n",
"        # replace punctuation and digits with spaces\n",
"        for t in chain(string.punctuation, string.digits):\n",
"            s = s.replace(t, \" \")\n",
"        return s\n",
"\n",
"    def __iter__(self):\n",
"        return self.tokens\n",
"\n",
"\n",
"def get_item_tokens_gen(texts, attr):\n",
"    # type: (List[Text], str) -> Generator[ItemTokens]\n",
"\n",
"    item_tokens = (ItemTokens(text, attr) for text in texts)\n",
"    return item_tokens"
]
},
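{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on a made-up lower-cased sentence (the Cranfield texts themselves are lower-cased, so the pipeline never needs case folding):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class _DemoText:\n",
"    # minimal stand-in for a parsed Text; \"w\" is a made-up sentence\n",
"    i = \"0\"\n",
"    w = \"the boundary layers of 2 compressible flows were studied .\"\n",
"\n",
"print(list(ItemTokens(_DemoText, \"w\")))\n",
"# expected: stemmed tokens like ['boundari', 'layer', 'compress', 'flow', 'studi']"
]
},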
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Inverted index"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class InvIndex:\n",
"    def __init__(self, text_tokens_gen):\n",
"        columns = [\"doc_id\", \"token\", \"count\"]\n",
"        index = [\"doc_id\", \"token\"]\n",
"\n",
"        def get_part_df():\n",
"            # one small frame per document: (doc_id, token) -> count\n",
"            for doc_id, text_tokens in tqdm(enumerate(text_tokens_gen)):\n",
"                data = [(doc_id, token, 1) for token in text_tokens]\n",
"                df = DataFrame(data, columns=columns).groupby(by=index).sum()\n",
"                yield df\n",
"\n",
"        self.df = pd.concat(get_part_df())\n",
"        self.df[\"count\"] = self.df[\"count\"].astype(np.float32)\n",
"\n",
"    @functools.lru_cache(maxsize=256, typed=False)\n",
"    def get_l(self, doc_id=None):\n",
"        # length of document doc_id; average document length if doc_id is None\n",
"        try:\n",
"            if doc_id is not None:\n",
"                return self.df.loc[doc_id, :][\"count\"].sum()\n",
"\n",
"            return self.df.reset_index().groupby(\"doc_id\").sum()[\"count\"].mean()\n",
"        except KeyError:\n",
"            pass\n",
"\n",
"    @functools.lru_cache(maxsize=256, typed=False)\n",
"    def get_f(self, t, doc_id=None):\n",
"        # count rows for token t: one (doc_id, t) row, or all documents containing t\n",
"        try:\n",
"            if doc_id is not None:\n",
"                return self.df.loc[(doc_id, t), :]\n",
"\n",
"            return self.df.loc[(slice(None), t), :]\n",
"        except KeyError:\n",
"            pass\n",
"\n",
"    @functools.lru_cache(maxsize=256, typed=False)\n",
"    def get_n(self, t=None):\n",
"        # total number of token occurrences; occurrences of t if given\n",
"        try:\n",
"            if t is None:\n",
"                return self.df[\"count\"].sum()\n",
"\n",
"            return self.df.loc[(slice(None), t), :][\"count\"].sum()\n",
"        except KeyError:\n",
"            return 0"
]
},
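{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sanity check on a toy corpus: `InvIndex` only needs an iterable of token iterables, so two short hand-written token lists are enough here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"toy_index = InvIndex(iter([[\"flow\", \"flow\", \"layer\"], [\"layer\", \"shock\"]]))\n",
"\n",
"print(toy_index.get_n())         # total occurrences: 5.0\n",
"print(toy_index.get_n(\"flow\"))   # occurrences of \"flow\": 2.0\n",
"print(toy_index.get_l(0))        # length of document 0: 3.0\n",
"print(toy_index.get_l())         # average document length: 2.5"
]
},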
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. RSV\n",
"\n",
"#### RSVRankedList\n",
"\n",
"The formula from the assignment (written out in full below).\n",
"\n",
"#### RSV2RankedList\n",
"\n",
"The assignment uses $IDF = \\log{(1 + \\frac{N - N_t + 0.5}{N_t + 0.5})}$, which is essentially the same as the lecture formula $IDF = \\log{\\frac{N}{N_t}}$. Here we compute $IDF$ as $\\log{\\frac{N - N_t + 0.5}{N_t + 0.5}}$, i.e. without the added one.\n",
"\n",
"#### RSVNormRankedList\n",
"\n",
"One of the tasks suggested normalizing $RSV$ by the sum of the $IDF$s of the query tokens. Since the $IDF$ formula does not depend on the document, this simply divides every $RSV(q, d)$ by a per-query constant, so the ranked list of relevant documents for a query does not change.\n",
"\n",
"#### RSVFullRankedList\n",
"\n",
"The full formula from the lecture."
]
},
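{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the score implemented in `RSVRankedList` below is the standard BM25 form (notation follows the code: $f_{t,d}$ is the count of $t$ in $d$, $L_d$ the document length, $\\overline{L}$ the average length, $N$ and $N_t$ as returned by `get_n`):\n",
"\n",
"$$RSV(q, d) = \\sum_{t \\in q} \\log{\\left(1 + \\frac{N - N_t + 0.5}{N_t + 0.5}\\right)} \\cdot \\frac{(k_1 + 1) f_{t,d}}{k_1 \\left((1 - b) + b \\frac{L_d}{\\overline{L}}\\right) + f_{t,d}}$$"
]
},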
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class RankedList(metaclass=abc.ABCMeta):\n",
"\n",
"    @abc.abstractmethod\n",
"    def __call__(self, q, inv_index):\n",
"        \"\"\"Return sorted list of docs relevant to the query\"\"\"\n",
"\n",
"\n",
"class RSVRankedList(RankedList):\n",
"    \"\"\"Formula from the task\"\"\"\n",
"    def __init__(self, k1, b):\n",
"        self.k1 = k1\n",
"        self.b = b\n",
"\n",
"    def __call__(self, q, inv_index):\n",
"        rsv = {}\n",
"        N = inv_index.get_n()\n",
"\n",
"        for t in q:\n",
"            Nt = inv_index.get_n(t)\n",
"            F = inv_index.get_f(t)\n",
"            idf = math.log(1.0 + (N - Nt + 0.5) / (Nt + 0.5))\n",
"\n",
"            if F is not None:\n",
"                for index, row in F.iterrows():\n",
"                    doc_id, ftd = index[0], row[\"count\"]\n",
"                    Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
"                    tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
"                    rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf\n",
"        return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n",
"\n",
"\n",
"class RSV2RankedList(RankedList):\n",
"    \"\"\"Formula from the lecture\"\"\"\n",
"    def __init__(self, k1, b):\n",
"        self.k1 = k1\n",
"        self.b = b\n",
"\n",
"    def __call__(self, q, inv_index):\n",
"        rsv = {}\n",
"        N = inv_index.get_n()\n",
"\n",
"        for t in q:\n",
"            Nt = inv_index.get_n(t)\n",
"            F = inv_index.get_f(t)\n",
"            idf = math.log((N - Nt + 0.5) / (Nt + 0.5))\n",
"\n",
"            if F is not None:\n",
"                for index, row in F.iterrows():\n",
"                    doc_id, ftd = index[0], row[\"count\"]\n",
"                    Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
"                    tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
"                    rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf\n",
"        return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n",
"\n",
"\n",
"class RSVNormRankedList(RankedList):\n",
"    \"\"\"Normed over sum of IDF's of terms\"\"\"\n",
"    def __init__(self, k1, b):\n",
"        self.k1 = k1\n",
"        self.b = b\n",
"\n",
"    def __call__(self, q, inv_index):\n",
"        rsv = {}\n",
"        N = inv_index.get_n()\n",
"        idf_sum = 0.0\n",
"\n",
"        for t in q:\n",
"            Nt = inv_index.get_n(t)\n",
"            F = inv_index.get_f(t)\n",
"            idf = math.log(1.0 + (N - Nt + 0.5) / (Nt + 0.5))\n",
"            idf_sum += idf\n",
"\n",
"            if F is not None:\n",
"                for index, row in F.iterrows():\n",
"                    doc_id, ftd = index[0], row[\"count\"]\n",
"                    Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
"                    tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
"                    rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf\n",
"\n",
"        for doc_id in rsv.keys():\n",
"            rsv[doc_id] /= idf_sum\n",
"\n",
"        return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n",
"\n",
"\n",
"class RSVFullRankedList(RankedList):\n",
"    \"\"\"Full formula\"\"\"\n",
"    def __init__(self, k1, k2, b):\n",
"        self.k1 = k1\n",
"        self.k2 = k2\n",
"        self.b = b\n",
"\n",
"    def __call__(self, q, inv_index):\n",
"        rsv = {}\n",
"        N = inv_index.get_n()\n",
"\n",
"        # term frequencies within the query itself\n",
"        q, ftq = list(q), {}\n",
"        for t in q:\n",
"            ftq[t] = ftq.get(t, 0) + 1\n",
"\n",
"        for t in q:\n",
"            Nt = inv_index.get_n(t)\n",
"            F = inv_index.get_f(t)\n",
"            idf = math.log(1.0 + (N - Nt + 0.5) / (Nt + 0.5))\n",
"            tf_q = (self.k2 + 1) * ftq[t] / (self.k2 + ftq[t])\n",
"\n",
"            if F is not None:\n",
"                for index, row in F.iterrows():\n",
"                    doc_id, ftd = index[0], row[\"count\"]\n",
"                    Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
"                    tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
"                    rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf * tf_q\n",
"        return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Evaluating search quality\n",
"\n",
"The metric values show that $29.02\\%$ of the documents we predicted turned out to be actually relevant, and they covered $42.46\\%$ of all relevant documents.\n",
"\n",
"On the one hand, the rarer a token, the larger its $IDF$; on the other hand, the more often a token occurs in a document, the larger its $TF$. If a word occurs frequently in many documents, it contributes roughly the same amount to all the corresponding $RSV$ values, while a word that occurs too rarely may never appear in a query at all, and most likely just contains a misspelling. The best words are those that occur very frequently in a few documents: they make a substantial contribution to the $RSV$ of exactly those documents.\n",
"\n",
"There are $4270$ unique tokens in total. A token occurs $29.97$ times on average. The most frequent token, flow, occurs $2080$ times, while $50\\%$ of tokens occur at most $3$ times."
]
},
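{
"cell_type": "markdown",
"metadata": {},
"source": [
"The metrics themselves are computed by the bundled `eval.py`, which is not reproduced here. For readability, a minimal sketch of the mean precision/recall/F-measure it reports might look like the following (assuming both qrel files hold one `query_id doc_id` pair per line, as written above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch only: the authoritative implementation is data/eval.py.\n",
"from collections import defaultdict\n",
"\n",
"def _read_qrel(path):\n",
"    rel = defaultdict(set)\n",
"    with open(path) as f:\n",
"        for line in f:\n",
"            query_id, doc_id = line.split()\n",
"            rel[query_id].add(doc_id)\n",
"    return rel\n",
"\n",
"def mean_precision_recall_f(pred_path, true_path):\n",
"    pred, true = _read_qrel(pred_path), _read_qrel(true_path)\n",
"    ps = [len(pred[q] & true[q]) / len(pred[q]) if pred[q] else 0.0 for q in true]\n",
"    rs = [len(pred[q] & true[q]) / len(true[q]) for q in true]\n",
"    p, r = sum(ps) / len(ps), sum(rs) / len(rs)\n",
"    return p, r, 2 * p * r / (p + r)"
]
},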
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"1400it [00:12, 111.30it/s]\n",
"225it [07:23, 3.39s/it]\n"
]
}
],
"source": [
"texts_gen = get_texts_gen(TEXTS_FILE)\n",
"text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
"\n",
"queries_gen = get_queries_gen(QUERIES_FILE)\n",
"query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
"\n",
"rsv = RSVRankedList(k1=1.2, b=0.75)\n",
"inv_index = InvIndex(text_tokens_gen)\n",
"with open(PREDICTION_FILE, \"w\") as fout:\n",
"    for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
"        ranked_list = rsv(query, inv_index)\n",
"        for doc_id, _ in ranked_list:\n",
"            fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean precision: 0.2902222222222223\n",
"mean recall: 0.42462172636971746\n",
"mean F-measure: 0.344787589721076\n",
"MAP@10: 0.35587606100053193\n"
]
}
],
"source": [
"!cd data && python3 eval.py"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"token\n",
"flow 2080.0\n",
"pressur 1390.0\n",
"number 1345.0\n",
"boundari 1214.0\n",
"layer 1161.0\n",
"result 1087.0\n",
"effect 997.0\n",
"method 888.0\n",
"theori 883.0\n",
"bodi 854.0\n",
"solut 849.0\n",
"heat 844.0\n",
"wing 838.0\n",
"mach 822.0\n",
"equat 777.0\n",
"Name: count, dtype: float32"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stat = inv_index.df.reset_index().groupby(\"token\").sum().sort_values(by=\"count\", ascending=False)[\"count\"]\n",
"stat.head(15)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"count 4270.000000\n",
"mean 29.973770\n",
"std 95.428993\n",
"min 1.000000\n",
"25% 1.000000\n",
"50% 3.000000\n",
"75% 16.000000\n",
"max 2080.000000\n",
"Name: count, dtype: float64"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stat.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Comparing the metrics for full texts and titles\n",
"\n",
"Abstracts contain more tokens than titles, so there is a better chance that they contain words specific to the text. They also contain more \"generic\" words, but those get a lower $tf-idf$ and hence less influence on $RSV$.\n",
"\n",
"| | precision | recall | F-measure | MAP@10 |\n",
"|-----------|-----------|--------|-----------|--------|\n",
"| abstracts | 0.2902 | 0.4246 | 0.3448 | 0.3558 |\n",
"| titles | 0.2577 | 0.3735 | 0.3050 | 0.2982 |"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"1400it [00:05, 258.37it/s]\n",
"225it [00:57, 3.90it/s]\n"
]
}
],
"source": [
"texts_gen = get_texts_gen(TEXTS_FILE)\n",
"text_tokens_gen = get_item_tokens_gen(texts_gen, \"t\")\n",
"\n",
"queries_gen = get_queries_gen(QUERIES_FILE)\n",
"query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
"\n",
"rsv = RSVRankedList(k1=1.2, b=0.75)\n",
"inv_index = InvIndex(text_tokens_gen)\n",
"with open(PREDICTION_FILE, \"w\") as fout:\n",
"    for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
"        ranked_list = rsv(query, inv_index)\n",
"        for doc_id, _ in ranked_list:\n",
"            fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean precision: 0.25777777777777794\n",
"mean recall: 0.3735774885450877\n",
"mean F-measure: 0.3050579601111924\n",
"MAP@10: 0.29823958945718215\n"
]
}
],
"source": [
"!cd data && python3 eval.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"### 6. Parameter tuning\n",
"\n",
"If $TF$ depended only on $f_{t, d}$ (for example, $\\frac{(k_1 + 1)f_{t, d}}{k_1 + f_{t, d}}$), then the longer the text, the larger this value could get for a token. The factor $\\frac{L_d}{\\overline{L}}$ is there to even out the chances of long and short texts, and the parameter $b$ says how important it is for us to account for this. The parameter $k_1$ says what matters more to us: the plain number of occurrences of a token, or the number of occurrences adjusted for the text length. (A small numeric illustration of the $TF$ saturation follows the table.)\n",
"\n",
"For a fixed $k_1$, the metric values increase as $b$ grows up to $\\approx 0.7$-$0.8$, and decrease as $b$ increases further. Two pairs of $k_1$ and $b$ look optimal (see the table, in bold).\n",
"\n",
"| k1 | b | precision | recall | F-score | MAP@10 |\n",
"|-----|-----|-----------|--------|---------|--------|\n",
"| 1.2 | 0.0 | 0.2502 | 0.3717 | 0.2991 | 0.2866 |\n",
"| 1.2 | 0.3 | 0.2724 | 0.4025 | 0.3249 | 0.3252 |\n",
"| **1.2** | **0.7** | **0.2876** | **0.4201** | **0.3414** | **0.3544** |\n",
"| 1.2 | 1.0 | 0.2898 | 0.4234 | 0.3440 | 0.3547 |\n",
"| 1.4 | 0.0 | 0.2507 | 0.3731 | 0.2999 | 0.2858 |\n",
"| 1.4 | 0.3 | 0.2733 | 0.4033 | 0.3258 | 0.3277 |\n",
"| 1.4 | 0.7 | 0.2942 | 0.4309 | 0.3497 | 0.3602 |\n",
"| 1.4 | 1.0 | 0.2916 | 0.4260 | 0.3462 | 0.3560 |\n",
"| 1.7 | 0.0 | 0.2498 | 0.3693 | 0.2980 | 0.2835 |\n",
"| 1.7 | 0.3 | 0.2791 | 0.4122 | 0.3328 | 0.3322 |\n",
"| 1.7 | 0.7 | 0.2951 | 0.4322 | 0.3507 | 0.3619 |\n",
"| 1.7 | 1.0 | 0.2924 | 0.4273 | 0.3472 | 0.3472 |\n",
"| 2.0 | 0.0 | 0.2498 | 0.3691 | 0.2979 | 0.2829 |\n",
"| 2.0 | 0.3 | 0.2818 | 0.4139 | 0.3353 | 0.3346 |\n",
"| **2.0** | **0.7** | **0.2969** | **0.4341** | **0.3526** | **0.3628** |\n",
"| 2.0 | 1.0 | 0.2947 | 0.4322 | 0.3504 | 0.3573 |"
]
},
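{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustration (not part of the assignment) of how the $TF$ factor saturates with the raw count $f_{t, d}$, shown for a document of average length ($L_d = \\overline{L}$):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# tf = f * (k1 + 1) / (k1 * ((1 - b) + b * Ld / L) + f); with Ld = L the\n",
"# length correction vanishes, so only the saturation in f remains visible.\n",
"for k1 in [1.2, 2.0]:\n",
"    for f in [1, 2, 5, 10, 50]:\n",
"        tf = f * (k1 + 1.) / (k1 + f)\n",
"        print(\"k1={} f={:2d} tf={:.3f}\".format(k1, f, tf))"
]
},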
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"1400it [00:10, 128.92it/s]\n",
"225it [04:28, 1.19s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.2 b=0.0\n",
" mean precision: 0.25022222222222235\n",
"mean recall: 0.37170986708272963\n",
"mean F-measure: 0.2991004019982698\n",
"MAP@10: 0.28660730032753856\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:27, 1.19s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.2 b=0.3\n",
" mean precision: 0.27244444444444454\n",
"mean recall: 0.4024949639631608\n",
"mean F-measure: 0.32494032940629947\n",
"MAP@10: 0.32524341283838654\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:28, 1.19s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.2 b=0.7\n",
" mean precision: 0.2875555555555556\n",
"mean recall: 0.4200675028154939\n",
"mean F-measure: 0.3414042058522218\n",
"MAP@10: 0.3544446656028666\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:28, 1.19s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.2 b=1.0\n",
" mean precision: 0.2897777777777779\n",
"mean recall: 0.42342078912050435\n",
"mean F-measure: 0.34407790769930774\n",
"MAP@10: 0.3546792670977857\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:32, 1.21s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.4 b=0.0\n",
" mean precision: 0.2506666666666668\n",
"mean recall: 0.3731463493139485\n",
"mean F-measure: 0.299882654466029\n",
"MAP@10: 0.28581725035693295\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:28, 1.19s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.4 b=0.3\n",
" mean precision: 0.27333333333333343\n",
"mean recall: 0.4033278430437178\n",
"mean F-measure: 0.3258438569079453\n",
"MAP@10: 0.3277454074633971\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:37, 1.23s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.4 b=0.7\n",
" mean precision: 0.2942222222222223\n",
"mean recall: 0.4308948236428148\n",
"mean F-measure: 0.3496782575425439\n",
"MAP@10: 0.36022575165868814\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:17, 1.41s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.4 b=1.0\n",
" mean precision: 0.29155555555555573\n",
"mean recall: 0.4260349113816855\n",
"mean F-measure: 0.34619424587426306\n",
"MAP@10: 0.35597095266089984\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:08, 1.37s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.7 b=0.0\n",
" mean precision: 0.24977777777777785\n",
"mean recall: 0.36934844968271546\n",
"mean F-measure: 0.29801688539612997\n",
"MAP@10: 0.2835236639511774\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:01, 1.34s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.7 b=0.3\n",
" mean precision: 0.2791111111111111\n",
"mean recall: 0.41221055909310067\n",
"mean F-measure: 0.33284808539625815\n",
"MAP@10: 0.33223726029506456\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:02, 1.34s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.7 b=0.7\n",
" mean precision: 0.2951111111111112\n",
"mean recall: 0.4322190950141451\n",
"mean F-measure: 0.35074208742843543\n",
"MAP@10: 0.36186320861677995\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:28, 1.46s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=1.7 b=1.0\n",
" mean precision: 0.2924444444444446\n",
"mean recall: 0.42730186664864084\n",
"mean F-measure: 0.3472391732369003\n",
"MAP@10: 0.35618574298031985\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:01, 1.34s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=2.0 b=0.0\n",
" mean precision: 0.24977777777777788\n",
"mean recall: 0.3691234216243541\n",
"mean F-measure: 0.297943607374251\n",
"MAP@10: 0.2828864365499286\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:18, 1.42s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=2.0 b=0.3\n",
" mean precision: 0.2817777777777778\n",
"mean recall: 0.4138503477846135\n",
"mean F-measure: 0.3352763554148436\n",
"MAP@10: 0.33462170501945643\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:07, 1.37s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=2.0 b=0.7\n",
" mean precision: 0.296888888888889\n",
"mean recall: 0.4341141567425401\n",
"mean F-measure: 0.35262143001032775\n",
"MAP@10: 0.36280025965118545\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:05, 1.36s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k1=2.0 b=1.0\n",
" mean precision: 0.29466666666666685\n",
"mean recall: 0.4321510729978472\n",
"mean F-measure: 0.3504056360415119\n",
"MAP@10: 0.3572706279219507\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"texts_gen = get_texts_gen(TEXTS_FILE)\n",
"text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
"inv_index = InvIndex(text_tokens_gen)\n",
"\n",
"for k1, b in product([1.2, 1.4, 1.7, 2.0], [0.0, 0.3, 0.7, 1.0]):\n",
"\n",
"    queries_gen = get_queries_gen(QUERIES_FILE)\n",
"    query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
"\n",
"    rsv = RSVRankedList(k1=k1, b=b)\n",
"    with open(PREDICTION_FILE, \"w\") as fout:\n",
"        for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
"            ranked_list = rsv(query, inv_index)\n",
"            for doc_id, _ in ranked_list:\n",
"                fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))\n",
"\n",
"    with subprocess.Popen(\"cd data && python3 eval.py\", shell=True, stdout=subprocess.PIPE) as p:\n",
"        print(\"k1={} b={}\\n {}\".format(k1, b, p.stdout.read().decode()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"### 7. A different IDF formula"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"1400it [00:11, 126.09it/s]\n",
"225it [04:47, 1.28s/it]\n"
]
}
],
"source": [
"texts_gen = get_texts_gen(TEXTS_FILE)\n",
"text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
"\n",
"queries_gen = get_queries_gen(QUERIES_FILE)\n",
"query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
"\n",
"rsv = RSV2RankedList(k1=1.2, b=0.75)\n",
"inv_index = InvIndex(text_tokens_gen)\n",
"with open(PREDICTION_FILE, \"w\") as fout:\n",
"    for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
"        ranked_list = rsv(query, inv_index)\n",
"        for doc_id, _ in ranked_list:\n",
"            fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean precision: 0.29066666666666674\n",
"mean recall: 0.42506617081416204\n",
"mean F-measure: 0.34524772516567925\n",
"MAP@10: 0.35587112272892696\n"
]
}
],
"source": [
"!cd data && python3 eval.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 8. Normalization\n",
"\n",
"The metric values coincide with the original ones: the normalization changed nothing, as expected, since dividing every $RSV(q, d)$ by a per-query constant cannot affect the ranking."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"1400it [00:10, 128.95it/s]\n",
"225it [05:03, 1.35s/it]\n"
]
}
],
"source": [
"texts_gen = get_texts_gen(TEXTS_FILE)\n",
"text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
"\n",
"queries_gen = get_queries_gen(QUERIES_FILE)\n",
"query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
"\n",
"rsv = RSVNormRankedList(k1=1.2, b=0.75)\n",
"inv_index = InvIndex(text_tokens_gen)\n",
"with open(PREDICTION_FILE, \"w\") as fout:\n",
"    for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
"        ranked_list = rsv(query, inv_index)\n",
"        for doc_id, _ in ranked_list:\n",
"            fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean precision: 0.2902222222222223\n",
"mean recall: 0.42462172636971746\n",
"mean F-measure: 0.344787589721076\n",
"MAP@10: 0.35587606100053193\n"
]
}
],
"source": [
"!cd data && python3 eval.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 9. The general RSV formula\n",
"\n",
"Adding the $TF(t, q)$ factor is an attempt to give more weight to the tokens that occur frequently in the query. In our case it only worsens the metric values; the optimum is at $k_2 = 0$. A small illustration of the query-side factor is given below."
]
},
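{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustration (not part of the assignment) of the query-side factor $TF(t, q) = \\frac{(k_2 + 1) f_{t, q}}{k_2 + f_{t, q}}$: at $k_2 = 0$ it is identically $1$, so the query term frequency is ignored, and as $k_2$ grows it approaches the raw count $f_{t, q}$, which is why the metrics barely change between $k_2 = 100$, $500$ and $1000$ in the sweep below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# tf_q for query term counts f = 1, 2, 3 and several values of k2\n",
"for k2 in [0., 1., 10., 100., 1000.]:\n",
"    print(\"k2={:6.0f}:\".format(k2),\n",
"          [\"{:.2f}\".format((k2 + 1) * f / (k2 + f)) for f in [1, 2, 3]])"
]
},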
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"1400it [00:11, 120.43it/s]\n",
"225it [04:58, 1.33s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=0.0 \n",
" mean precision: 0.2902222222222223\n",
"mean recall: 0.42462172636971746\n",
"mean F-measure: 0.344787589721076\n",
"MAP@10: 0.35587606100053193\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:02, 1.34s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=1.0 \n",
" mean precision: 0.28888888888888903\n",
"mean recall: 0.4202494319974231\n",
"mean F-measure: 0.3424025691184888\n",
"MAP@10: 0.35060137174211237\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:59, 1.33s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=5.0 \n",
" mean precision: 0.28844444444444456\n",
"mean recall: 0.41883657325123097\n",
"mean F-measure: 0.34162116517157276\n",
"MAP@10: 0.3458602292768961\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:03, 1.35s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=10.0 \n",
" mean precision: 0.2875555555555556\n",
"mean recall: 0.41788419229885004\n",
"mean F-measure: 0.340680891429387\n",
"MAP@10: 0.34389379776602\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:04, 1.35s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=50.0 \n",
" mean precision: 0.2862222222222223\n",
"mean recall: 0.4152460156606736\n",
"mean F-measure: 0.33886819374753546\n",
"MAP@10: 0.3422005563953977\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [05:03, 1.35s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=100.0 \n",
" mean precision: 0.2862222222222223\n",
"mean recall: 0.4152460156606736\n",
"mean F-measure: 0.33886819374753546\n",
"MAP@10: 0.3421194276476022\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [04:59, 1.33s/it]\n",
"0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=500.0 \n",
" mean precision: 0.2862222222222223\n",
"mean recall: 0.4152460156606736\n",
"mean F-measure: 0.33886819374753546\n",
"MAP@10: 0.3419881099353322\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"225it [08:09, 2.18s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"k2=1000.0 \n",
" mean precision: 0.2862222222222223\n",
"mean recall: 0.4152460156606736\n",
"mean F-measure: 0.33886819374753546\n",
"MAP@10: 0.3419881099353322\n",
"\n"
]
}
],
"source": [
"texts_gen = get_texts_gen(TEXTS_FILE)\n",
"text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
"inv_index = InvIndex(text_tokens_gen)\n",
"\n",
"for k2 in [0., 1., 5., 10., 50., 100., 500., 1000.]:\n",
"    queries_gen = get_queries_gen(QUERIES_FILE)\n",
"    query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
"\n",
"    rsv = RSVFullRankedList(k1=1.2, k2=k2, b=0.75)\n",
"    with open(PREDICTION_FILE, \"w\") as fout:\n",
"        for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
"            ranked_list = rsv(query, inv_index)\n",
"            for doc_id, _ in ranked_list:\n",
"                fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))\n",
"\n",
"    with subprocess.Popen(\"cd data && python3 eval.py\", shell=True, stdout=subprocess.PIPE) as p:\n",
"        print(\"k2={} \\n {}\".format(k2, p.stdout.read().decode()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}