Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Необходимо положить архив `data.tar.gz` в директорию `data`, распаковать, после чего: \n",
- "\n",
- "* изменить имя файла с правильными ответами на `test.qrel_clean`;\n",
- "* создать файл для предсказаний с названием `train.qrel_clean`;\n",
- "* поменять названия файлов в `eval.py`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "data/\n",
- "data/eval.py\n",
- "data/cran.qry\n",
- "data/qrel_clean\n",
- "data/cran.all.1400\n"
- ]
- }
- ],
- "source": [
- "!tar -xvzf data/data.tar.gz"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "import sys\n",
- "import os\n",
- "import abc\n",
- "import math\n",
- "import string\n",
- "import functools\n",
- "import subprocess\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "from pandas import DataFrame\n",
- "\n",
- "from itertools import chain, product\n",
- "from operator import itemgetter\n",
- "\n",
- "import nltk\n",
- "from nltk.tokenize import word_tokenize\n",
- "from nltk.stem import WordNetLemmatizer\n",
- "from nltk.stem import PorterStemmer\n",
- "from nltk.corpus import stopwords\n",
- "\n",
- "\n",
- "from typing import (\n",
- " Optional,\n",
- " Generator,\n",
- " List\n",
- ")\n",
- "\n",
- "from tqdm import tqdm\n",
- "\n",
- "stop_words = set(stopwords.words('english'))\n",
- "lemmatizer = WordNetLemmatizer()\n",
- "stemmer = PorterStemmer()\n",
- "\n",
- "DIRECTORY = \"data\"\n",
- "TEXTS_FILE = os.path.join(DIRECTORY, \"cran.all.1400\")\n",
- "QUERIES_FILE = os.path.join(DIRECTORY, \"cran.qry\")\n",
- "PREDICTION_FILE = os.path.join(DIRECTORY, \"train.qrel_clean\")\n",
- "CORRECT_ANSWERS_FILE = os.path.join(DIRECTORY, \"test.qrel_clean\")\n",
- "NUMBER_TEXTS = 1400\n",
- "NUMBER_QUERIES = 225"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1.Парсинг"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "class Text:\n",
- " __slots__ = [\"i\", \"t\", \"a\", \"b\", \"w\"]\n",
- " \n",
- " def __init__(self):\n",
- " self.i = None # type: str\n",
- " self.t = None # type: str\n",
- " self.a = None # type: str\n",
- " self.b = None # type: str\n",
- " self.w = None # type: str\n",
- "\n",
- " \n",
- "class Query:\n",
- " __slots__ = [\"i\", \"w\"]\n",
- " \n",
- " def __init__(self):\n",
- " self.i = None # type: str\n",
- " self.w = None # type: str\n",
- "\n",
- "\n",
- "def _read_file(filepath):\n",
- " # type: (str) -> str\n",
- " \n",
- " file = open(filepath)\n",
- " yield from file\n",
- " \n",
- "\n",
- "def _parse(gen, cls, line_starts):\n",
- " # type: (Generator[str]) -> Text\n",
- " \n",
- " def set_current_state(line):\n",
- " nonlocal current_state\n",
- " \n",
- " for i, (s, _) in enumerate(line_starts):\n",
- " if line.startswith(s):\n",
- " current_state = i\n",
- " \n",
- " def yield_text():\n",
- " t = cls()\n",
- " for i, (_, s) in enumerate(line_starts):\n",
- " setattr(t, s, ''.join(text_lists[i]))\n",
- " \n",
- " return t\n",
- " \n",
- " text_lists = [[] for _ in range(len(line_starts))]\n",
- " current_state = -1\n",
- "\n",
- " for line in chain(gen, ['.I']):\n",
- " set_current_state(line)\n",
- " \n",
- " if current_state == 0:\n",
- " if any(text_lists):\n",
- " yield yield_text()\n",
- " text_lists = [[] for _ in range(len(line_starts))]\n",
- " \n",
- " text_lists[current_state].append(line)\n",
- "\n",
- "\n",
- "def get_texts_gen(filepath):\n",
- " # type: (str) -> Generator[Text]\n",
- "\n",
- " gen = _read_file(filepath)\n",
- " texts = _parse(gen, Text, [('.I', 'i'), ('.T', 't'), ('.A', 'a'), ('.B', 'b'), ('.W', 'w')])\n",
- " return texts\n",
- "\n",
- "\n",
- "def get_queries_gen(filepath):\n",
- " # type: (str) -> Generator[Query]\n",
- " \n",
- " gen = _read_file(filepath)\n",
- " queries = _parse(gen, Query, [('.I', 'i'), ('.W', 'w')])\n",
- " return queries"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Нормализация\n",
- "\n",
- "Используем библиотеку `nltk`. \n",
- "\n",
- "* Удалим из текста все цифры и символы пунктуации. \n",
- "* Разобьём на токены, удалим стоп-слова и односимвольные токены.\n",
- "* Проведём стемминг и лемматизацию."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "class ItemTokens:\n",
- " __slots__ = [\"i\", \"tokens\"]\n",
- " \n",
- " def __init__(self, text, attr):\n",
- " self.i = text.i\n",
- " \n",
- " assert hasattr(text, attr)\n",
- " \n",
- " self.tokens = self._filter(self._clean(getattr(text, attr)))\n",
- " \n",
- " def _filter(self, s):\n",
- " stop_tokens = [\"I\", \"T\", \"A\", \"B\", \"W\"]\n",
- " return (stemmer.stem(lemmatizer.lemmatize(t)) for t in word_tokenize(s) \n",
- " if (t not in stop_words) and (t not in stop_tokens) and (len(t) >= 2))\n",
- " \n",
- " def _clean(self, s):\n",
- " for t in chain(string.punctuation, string.digits):\n",
- " s = s.replace(t, \" \")\n",
- " return s\n",
- " \n",
- " def __iter__(self):\n",
- " return self.tokens\n",
- "\n",
- " \n",
- "def get_item_tokens_gen(texts, attr):\n",
- " # type: (List[Text], str) -> Generator[TextTokens]\n",
- " \n",
- " item_tokens = (ItemTokens(text, attr) for text in texts)\n",
- " return item_tokens"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2. Инвертированный индекс"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "class InvIndex:\n",
- " def __init__(self, text_tokens_gen):\n",
- " columns = [\"doc_id\", \"token\", \"count\"]\n",
- " index = [\"doc_id\", \"token\"]\n",
- " \n",
- " def get_part_df():\n",
- " for doc_id, text_tokens in tqdm(enumerate(text_tokens_gen)):\n",
- " data = [(doc_id, token, 1) for token in text_tokens]\n",
- " df = DataFrame(data, columns=columns).groupby(by=index).sum()\n",
- " yield df\n",
- " \n",
- " self.df = pd.concat(get_part_df())\n",
- " self.df[\"count\"] = self.df[\"count\"].astype(np.float32)\n",
- " \n",
- " @functools.lru_cache(maxsize=256, typed=False)\n",
- " def get_l(self, doc_id=None):\n",
- " try:\n",
- " if doc_id is not None:\n",
- " return self.df.loc[doc_id, :][\"count\"].sum()\n",
- " \n",
- " return self.df.reset_index().groupby(\"doc_id\").sum()[\"count\"].mean()\n",
- " except KeyError:\n",
- " pass\n",
- " \n",
- " @functools.lru_cache(maxsize=256, typed=False)\n",
- " def get_f(self, t, doc_id=None):\n",
- " try:\n",
- " if doc_id is not None:\n",
- " return self.df.loc[(doc_id, t), :]\n",
- " \n",
- " return self.df.loc[(slice(None), t), :]\n",
- " except KeyError:\n",
- " pass\n",
- " \n",
- " @functools.lru_cache(maxsize=256, typed=False)\n",
- " def get_n(self, t=None):\n",
- " try:\n",
- " if t is None:\n",
- " return self.df[\"count\"].sum()\n",
- "\n",
- " return self.df.loc[(slice(None), t), :][\"count\"].sum()\n",
- " except KeyError:\n",
- " return 0 "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3. RSV\n",
- "\n",
- "#### RSVRankedList\n",
- "\n",
- "Формула из условия.\n",
- "\n",
- "#### RSV2RankedList\n",
- "\n",
- "В условии $IDF = \\log{(1 + \\frac{N - N_t + 0.5}{N_t + 0.5})}$, что по сути равно формуле из лекции $IDF = \\log{\\frac{N} {N_t}}$. Будем считать $IDF$ по формуле $\\log{\\frac{N - N_t}{N_t}}$.\n",
- "\n",
- "#### RSVNormRankedList\n",
- "\n",
- "В одном из заданий предлагалось пронормировать $RSV$ на сумму $IDF$ токенов запроса. Т.к. формула $IDF$ не зависит от документа, то по сути мы просто поделим $RSV(q, d)$ на константы, и список релевантных документов для запроса не изменится.\n",
- "\n",
- "#### RSVFullRankedList\n",
- "\n",
- "Полная формула из лекции."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "class RankedList:\n",
- " __metaclass__ = abc.ABCMeta\n",
- "\n",
- " @abc.abstractmethod\n",
- " def __call__(self, q, inv_index):\n",
- " \"\"\"Return sorted list of docs relevant to the query\"\"\"\n",
- "\n",
- "\n",
- "class RSVRankedList(RankedList):\n",
- " \"\"\"Formula from the task\"\"\"\n",
- " def __init__(self, k1, b):\n",
- " self.k1 = k1\n",
- " self.b = b\n",
- " \n",
- " def __call__(self, q, inv_index):\n",
- " rsv = {}\n",
- " N = inv_index.get_n()\n",
- " \n",
- " for t in q:\n",
- " Nt = inv_index.get_n(t)\n",
- " F = inv_index.get_f(t)\n",
- " idf = math.log(1.0 + (N - Nt + 0.5) / (Nt + 0.5))\n",
- " \n",
- " if F is not None:\n",
- " for index, row in F.iterrows():\n",
- " doc_id, ftd = index[0], row[\"count\"]\n",
- " Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
- " tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
- " rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf\n",
- " return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n",
- "\n",
- " \n",
- "class RSV2RankedList(RankedList):\n",
- " \"\"\"Formula from the lecture\"\"\"\n",
- " def __init__(self, k1, b):\n",
- " self.k1 = k1\n",
- " self.b = b\n",
- " \n",
- " def __call__(self, q, inv_index):\n",
- " rsv = {}\n",
- " N = inv_index.get_n()\n",
- " \n",
- " for t in q:\n",
- " Nt = inv_index.get_n(t)\n",
- " F = inv_index.get_f(t)\n",
- " idf = math.log((N - Nt + 0.5) / (Nt + 0.5))\n",
- " \n",
- " if F is not None:\n",
- " for index, row in F.iterrows():\n",
- " doc_id, ftd = index[0], row[\"count\"]\n",
- " Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
- " tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
- " rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf\n",
- " return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n",
- "\n",
- "\n",
- "class RSVNormRankedList(RankedList):\n",
- " \"\"\"Normed over sum of IDF's of terms\"\"\"\n",
- " def __init__(self, k1, b):\n",
- " self.k1 = k1\n",
- " self.b = b\n",
- " \n",
- " def __call__(self, q, inv_index):\n",
- " rsv = {}\n",
- " N = inv_index.get_n()\n",
- " idf_sum = 0.0\n",
- " \n",
- " for t in q:\n",
- " Nt = inv_index.get_n(t)\n",
- " F = inv_index.get_f(t)\n",
- " idf = math.log(1.0 + (N - Nt + 0.5) / (Nt + 0.5))\n",
- " idf_sum += idf\n",
- " \n",
- " if F is not None:\n",
- " for index, row in F.iterrows():\n",
- " doc_id, ftd = index[0], row[\"count\"]\n",
- " Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
- " tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
- " rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf\n",
- " \n",
- " for doc_id in rsv.keys():\n",
- " rsv[doc_id] /= idf_sum\n",
- "\n",
- " return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n",
- "\n",
- "\n",
- "class RSVFullRankedList(RankedList):\n",
- " \"\"\"Full formula\"\"\"\n",
- " def __init__(self, k1, k2, b):\n",
- " self.k1 = k1\n",
- " self.k2 = k2\n",
- " self.b = b\n",
- " \n",
- " def __call__(self, q, inv_index):\n",
- " rsv = {}\n",
- " N = inv_index.get_n()\n",
- " \n",
- " q, ftq = list(q), {}\n",
- " for t in q:\n",
- " ftq[t] = ftq.get(t, 0) + 1\n",
- " \n",
- " for t in q:\n",
- " Nt = inv_index.get_n(t)\n",
- " F = inv_index.get_f(t)\n",
- " idf = math.log(1.0 + (N - Nt + 0.5) / (Nt + 0.5))\n",
- " tf_q = (self.k2 + 1) * ftq[t] / (self.k2 + ftq[t])\n",
- " \n",
- " if F is not None:\n",
- " for index, row in F.iterrows():\n",
- " doc_id, ftd = index[0], row[\"count\"]\n",
- " Ld, L = inv_index.get_l(doc_id), inv_index.get_l()\n",
- " tf = ftd * (self.k1 + 1.) / (self.k1 * ((1. - self.b) + self.b * Ld / L) + ftd)\n",
- " rsv[doc_id] = rsv.get(doc_id, 0) + idf * tf * tf_q\n",
- " return sorted(rsv.items(), key=itemgetter(1), reverse=True)[:10]\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4. Оценка качества поиска\n",
- "\n",
- "По значениям метрик можно понять, что среди предсказанных нами документов действительно релевантными оказалось $29.02\\%$, но это составило $42.46\\%$ от числа всех релевантных документов.\n",
- "\n",
- "С одной стороны, чем реже встречается токен, тем больше $IDF$, с другой, чем больше встречается токен в документе, тем больше $TF$. Если же слово часто встречается во многих документах, то оно будет вносить примерно одинаковый вклад во все соответствующие $RSV$, а слишком редко встречающееся слово может не попасться в самом запросе и, скорее всего, просто содержит ошибку. Наилучшими являются слова, которые очень часто встречаются в некоторых документах, тем самым они будут вносить существенный вклад в $RSV$ этих документов.\n",
- "\n",
- "Всего уникальных токенов $4270$. Среднее количество вхождений одного токена - $29.97$. Самый популярный токен - flow - встречается $2080$ раз, а $50\\%$ токенов встречаются не более $3$-х раз. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "1400it [00:12, 111.30it/s]\n",
- "225it [07:23, 3.39s/it]\n"
- ]
- }
- ],
- "source": [
- "texts_gen = get_texts_gen(TEXTS_FILE)\n",
- "text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
- "\n",
- "queries_gen = get_queries_gen(QUERIES_FILE)\n",
- "query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
- "\n",
- "rsv = RSVRankedList(k1=1.2, b=0.75)\n",
- "inv_index = InvIndex(text_tokens_gen)\n",
- "with open(PREDICTION_FILE, \"w\") as fout:\n",
- " for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
- " ranked_list = rsv(query, inv_index)\n",
- " for doc_id, _ in ranked_list:\n",
- " fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "mean precision: 0.2902222222222223\n",
- "mean recall: 0.42462172636971746\n",
- "mean F-measure: 0.344787589721076\n",
- "MAP@10: 0.35587606100053193\n"
- ]
- }
- ],
- "source": [
- "!cd data && python3 eval.py"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "token\n",
- "flow 2080.0\n",
- "pressur 1390.0\n",
- "number 1345.0\n",
- "boundari 1214.0\n",
- "layer 1161.0\n",
- "result 1087.0\n",
- "effect 997.0\n",
- "method 888.0\n",
- "theori 883.0\n",
- "bodi 854.0\n",
- "solut 849.0\n",
- "heat 844.0\n",
- "wing 838.0\n",
- "mach 822.0\n",
- "equat 777.0\n",
- "Name: count, dtype: float32"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "stat = inv_index.df.reset_index().groupby(\"token\").sum().sort_values(by=\"count\", ascending=False)[\"count\"]\n",
- "stat.head(15)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "count 4270.000000\n",
- "mean 29.973770\n",
- "std 95.428993\n",
- "min 1.000000\n",
- "25% 1.000000\n",
- "50% 3.000000\n",
- "75% 16.000000\n",
- "max 2080.000000\n",
- "Name: count, dtype: float64"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "stat.describe()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 5. Сравнение значений метрик для полных текстов и заголовков\n",
- "\n",
- "Аннотации содержат больше токенов чем заголовок, а значит больше шанс, что они содержат слова, специфичные для текста. При этом они содержат больше \"общезначимых\" слов, но у них будет меньше $tf-idf$, а значит и влияние на $RSV$.\n",
- "\n",
- "| | precision | recall | F-measure | MAP@10 |\n",
- "|-----------|-----------|--------|-----------|--------|\n",
- "| аннотации | 0.2902 | 0.4246 | 0.3448 | 0.3558 |\n",
- "| заголовки | 0.2577 | 0.3735 | 0.3050 | 0.2982 |"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "1400it [00:05, 258.37it/s]\n",
- "225it [00:57, 3.90it/s]\n"
- ]
- }
- ],
- "source": [
- "texts_gen = get_texts_gen(TEXTS_FILE)\n",
- "text_tokens_gen = get_item_tokens_gen(texts_gen, \"t\")\n",
- "\n",
- "queries_gen = get_queries_gen(QUERIES_FILE)\n",
- "query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
- "\n",
- "rsv = RSVRankedList(k1=1.2, b=0.75)\n",
- "inv_index = InvIndex(text_tokens_gen)\n",
- "with open(PREDICTION_FILE, \"w\") as fout:\n",
- " for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
- " ranked_list = rsv(query, inv_index)\n",
- " for doc_id, _ in ranked_list:\n",
- " fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "mean precision: 0.25777777777777794\n",
- "mean recall: 0.3735774885450877\n",
- "mean F-measure: 0.3050579601111924\n",
- "MAP@10: 0.29823958945718215\n"
- ]
- }
- ],
- "source": [
- "!cd data && python3 eval.py"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": true
- },
- "source": [
- "### 6. Подбор параметров\n",
- "\n",
- "Если бы в $TF$ учитывались только $f(t, d)$ [например, $\\frac{(k1 + 1)f_{t, d}}{k1 + f_{t, d}}$)], то чем длиннее был бы текст, тем больше это значение могло бы быть для токена. Поэтому $\\frac{L_d}{\\overline{L}}$ призван уравновесить шансы длинных и коротких текстов, а параметр $b$ обозначает насколько важно нам учитывать этот факт. А $k_1$ обозначает что нам важнее, просто количество вхождений токена или количество вхождений токена при учёте длины текста.\n",
- "\n",
- "При фиксированном $k_1$ значения метрик увеличиваются при росте $b$ до $\\approx 0.7-0.8$, и уменьшаются при дальнейшем увеличении $b$. И оптимальными кажутся две пары $k_1$ и $b$ (см. таблицу).\n",
- "\n",
- "| k1 | b | precision | recall | F-score | MAP@10 |\n",
- "|-----|-----|-----------|--------|---------|--------|\n",
- "| 1.2 | 0.0 | 0.2502 | 0.3717 | 0.2991 | 0.2866 |\n",
- "| 1.2 | 0.3 | 0.2724 | 0.4025 | 0.3249 | 0.3252 |\n",
- "| <b> 1.2 | <b> 0.7 | <b> 0.2876 | <b> 0.4201 | <b>0.3414 | <b>0.3544 |\n",
- "| 1.2 | 1.0 | 0.2898 | 0.4234 | 0.3440 | 0.3547 |\n",
- "| 1.4 | 0.0 | 0.2507 | 0.3731 | 0.2999 | 0.2858 |\n",
- "| 1.4 | 0.3 | 0.2733 | 0.4033 | 0.3258 | 0.3277 |\n",
- "| 1.4 | 0.7 | 0.2942 | 0.4309 | 0.3497 | 0.3602 |\n",
- "| 1.4 | 1.0 | 0.2916 | 0.4260 | 0.3462 | 0.3560 |\n",
- "| 1.7 | 0.0 | 0.2498 | 0.3693 | 0.2980 | 0.2835 |\n",
- "| 1.7 | 0.3 | 0.2791 | 0.4122 | 0.3328 | 0.3322 |\n",
- "| 1.7 | 0.7 | 0.2951 | 0.4322 | 0.3507 | 0.3619 |\n",
- "| 1.7 | 1.0 | 0.2924 | 0.4273 | 0.3472 | 0.3472 |\n",
- "| 2.0 | 0.0 | 0.2498 | 0.3691 | 0.2979 | 0.2829 |\n",
- "| 2.0 | 0.3 | 0.2818 | 0.4139 | 0.3353 | 0.3346 |\n",
- "| <b> 2.0 | <b> 0.7 | <b> 0.2969 | <b> 0.4341 | <b> 0.3526 | <b> 0.3628 |\n",
- "| 2.0 | 1.0 | 0.2947 | 0.4322 | 0.3504 | 0.3573 |"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "1400it [00:10, 128.92it/s]\n",
- "225it [04:28, 1.19s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.2 b=0.0\n",
- " mean precision: 0.25022222222222235\n",
- "mean recall: 0.37170986708272963\n",
- "mean F-measure: 0.2991004019982698\n",
- "MAP@10: 0.28660730032753856\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:27, 1.19s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.2 b=0.3\n",
- " mean precision: 0.27244444444444454\n",
- "mean recall: 0.4024949639631608\n",
- "mean F-measure: 0.32494032940629947\n",
- "MAP@10: 0.32524341283838654\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:28, 1.19s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.2 b=0.7\n",
- " mean precision: 0.2875555555555556\n",
- "mean recall: 0.4200675028154939\n",
- "mean F-measure: 0.3414042058522218\n",
- "MAP@10: 0.3544446656028666\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:28, 1.19s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.2 b=1.0\n",
- " mean precision: 0.2897777777777779\n",
- "mean recall: 0.42342078912050435\n",
- "mean F-measure: 0.34407790769930774\n",
- "MAP@10: 0.3546792670977857\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:32, 1.21s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.4 b=0.0\n",
- " mean precision: 0.2506666666666668\n",
- "mean recall: 0.3731463493139485\n",
- "mean F-measure: 0.299882654466029\n",
- "MAP@10: 0.28581725035693295\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:28, 1.19s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.4 b=0.3\n",
- " mean precision: 0.27333333333333343\n",
- "mean recall: 0.4033278430437178\n",
- "mean F-measure: 0.3258438569079453\n",
- "MAP@10: 0.3277454074633971\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:37, 1.23s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.4 b=0.7\n",
- " mean precision: 0.2942222222222223\n",
- "mean recall: 0.4308948236428148\n",
- "mean F-measure: 0.3496782575425439\n",
- "MAP@10: 0.36022575165868814\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:17, 1.41s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.4 b=1.0\n",
- " mean precision: 0.29155555555555573\n",
- "mean recall: 0.4260349113816855\n",
- "mean F-measure: 0.34619424587426306\n",
- "MAP@10: 0.35597095266089984\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:08, 1.37s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.7 b=0.0\n",
- " mean precision: 0.24977777777777785\n",
- "mean recall: 0.36934844968271546\n",
- "mean F-measure: 0.29801688539612997\n",
- "MAP@10: 0.2835236639511774\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:01, 1.34s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.7 b=0.3\n",
- " mean precision: 0.2791111111111111\n",
- "mean recall: 0.41221055909310067\n",
- "mean F-measure: 0.33284808539625815\n",
- "MAP@10: 0.33223726029506456\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:02, 1.34s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.7 b=0.7\n",
- " mean precision: 0.2951111111111112\n",
- "mean recall: 0.4322190950141451\n",
- "mean F-measure: 0.35074208742843543\n",
- "MAP@10: 0.36186320861677995\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:28, 1.46s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=1.7 b=1.0\n",
- " mean precision: 0.2924444444444446\n",
- "mean recall: 0.42730186664864084\n",
- "mean F-measure: 0.3472391732369003\n",
- "MAP@10: 0.35618574298031985\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:01, 1.34s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=2.0 b=0.0\n",
- " mean precision: 0.24977777777777788\n",
- "mean recall: 0.3691234216243541\n",
- "mean F-measure: 0.297943607374251\n",
- "MAP@10: 0.2828864365499286\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:18, 1.42s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=2.0 b=0.3\n",
- " mean precision: 0.2817777777777778\n",
- "mean recall: 0.4138503477846135\n",
- "mean F-measure: 0.3352763554148436\n",
- "MAP@10: 0.33462170501945643\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:07, 1.37s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=2.0 b=0.7\n",
- " mean precision: 0.296888888888889\n",
- "mean recall: 0.4341141567425401\n",
- "mean F-measure: 0.35262143001032775\n",
- "MAP@10: 0.36280025965118545\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:05, 1.36s/it]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k1=2.0 b=1.0\n",
- " mean precision: 0.29466666666666685\n",
- "mean recall: 0.4321510729978472\n",
- "mean F-measure: 0.3504056360415119\n",
- "MAP@10: 0.3572706279219507\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\n"
- ]
- }
- ],
- "source": [
- "texts_gen = get_texts_gen(TEXTS_FILE)\n",
- "text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
- "inv_index = InvIndex(text_tokens_gen)\n",
- "\n",
- "for k1, b in product([1.2, 1.4, 1.7, 2.0], [0.0, 0.3, 0.7, 1.0]):\n",
- "\n",
- " queries_gen = get_queries_gen(QUERIES_FILE)\n",
- " query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
- "\n",
- " rsv = RSVRankedList(k1=k1, b=b)\n",
- " with open(PREDICTION_FILE, \"w\") as fout:\n",
- " for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
- " ranked_list = rsv(query, inv_index)\n",
- " for doc_id, _ in ranked_list:\n",
- " fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))\n",
- " \n",
- " with subprocess.Popen(\"cd data && python3 eval.py\", shell=True, stdout=subprocess.PIPE) as p:\n",
- " print(\"k1={} b={}\\n {}\".format(k1, b, p.stdout.read().decode()))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": true
- },
- "source": [
- "### 7. Другая формула вычисления IDF"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "1400it [00:11, 126.09it/s]\n",
- "225it [04:47, 1.28s/it]\n"
- ]
- }
- ],
- "source": [
- "texts_gen = get_texts_gen(TEXTS_FILE)\n",
- "text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
- "\n",
- "queries_gen = get_queries_gen(QUERIES_FILE)\n",
- "query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
- "\n",
- "rsv = RSV2RankedList(k1=1.2, b=0.75)\n",
- "inv_index = InvIndex(text_tokens_gen)\n",
- "with open(PREDICTION_FILE, \"w\") as fout:\n",
- " for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
- " ranked_list = rsv(query, inv_index)\n",
- " for doc_id, _ in ranked_list:\n",
- " fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "mean precision: 0.29066666666666674\n",
- "mean recall: 0.42506617081416204\n",
- "mean F-measure: 0.34524772516567925\n",
- "MAP@10: 0.35587112272892696\n"
- ]
- }
- ],
- "source": [
- "!cd data && python3 eval.py"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 8. Нормировка\n",
- "\n",
- "Значения метрик совпадают с изначальными. Нормировка ничего не изменила."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "1400it [00:10, 128.95it/s]\n",
- "225it [05:03, 1.35s/it]\n"
- ]
- }
- ],
- "source": [
- "texts_gen = get_texts_gen(TEXTS_FILE)\n",
- "text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
- "\n",
- "queries_gen = get_queries_gen(QUERIES_FILE)\n",
- "query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
- "\n",
- "rsv = RSVNormRankedList(k1=1.2, b=0.75)\n",
- "inv_index = InvIndex(text_tokens_gen)\n",
- "with open(PREDICTION_FILE, \"w\") as fout:\n",
- " for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
- " ranked_list = rsv(query, inv_index)\n",
- " for doc_id, _ in ranked_list:\n",
- " fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "mean precision: 0.2902222222222223\n",
- "mean recall: 0.42462172636971746\n",
- "mean F-measure: 0.344787589721076\n",
- "MAP@10: 0.35587606100053193\n"
- ]
- }
- ],
- "source": [
- "!cd data && python3 eval.py"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 9. Общая формула вычисления RSV\n",
- "\n",
- "Добавление множителя $TF(t, q)$ - попытка больше учитывать те токены, которые часто встречаются в запросе. Но в нашем случае, это только ухудшает значения метрик. Оптимальное значение при $k2 = 0$."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "1400it [00:11, 120.43it/s]\n",
- "225it [04:58, 1.33s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=0.0 \n",
- " mean precision: 0.2902222222222223\n",
- "mean recall: 0.42462172636971746\n",
- "mean F-measure: 0.344787589721076\n",
- "MAP@10: 0.35587606100053193\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:02, 1.34s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=1.0 \n",
- " mean precision: 0.28888888888888903\n",
- "mean recall: 0.4202494319974231\n",
- "mean F-measure: 0.3424025691184888\n",
- "MAP@10: 0.35060137174211237\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:59, 1.33s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=5.0 \n",
- " mean precision: 0.28844444444444456\n",
- "mean recall: 0.41883657325123097\n",
- "mean F-measure: 0.34162116517157276\n",
- "MAP@10: 0.3458602292768961\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:03, 1.35s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=10.0 \n",
- " mean precision: 0.2875555555555556\n",
- "mean recall: 0.41788419229885004\n",
- "mean F-measure: 0.340680891429387\n",
- "MAP@10: 0.34389379776602\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:04, 1.35s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=50.0 \n",
- " mean precision: 0.2862222222222223\n",
- "mean recall: 0.4152460156606736\n",
- "mean F-measure: 0.33886819374753546\n",
- "MAP@10: 0.3422005563953977\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [05:03, 1.35s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=100.0 \n",
- " mean precision: 0.2862222222222223\n",
- "mean recall: 0.4152460156606736\n",
- "mean F-measure: 0.33886819374753546\n",
- "MAP@10: 0.3421194276476022\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [04:59, 1.33s/it]\n",
- "0it [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=500.0 \n",
- " mean precision: 0.2862222222222223\n",
- "mean recall: 0.4152460156606736\n",
- "mean F-measure: 0.33886819374753546\n",
- "MAP@10: 0.3419881099353322\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "225it [08:09, 2.18s/it]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "k2=1000.0 \n",
- " mean precision: 0.2862222222222223\n",
- "mean recall: 0.4152460156606736\n",
- "mean F-measure: 0.33886819374753546\n",
- "MAP@10: 0.3419881099353322\n",
- "\n"
- ]
- }
- ],
- "source": [
- "texts_gen = get_texts_gen(TEXTS_FILE)\n",
- "text_tokens_gen = get_item_tokens_gen(texts_gen, \"w\")\n",
- "inv_index = InvIndex(text_tokens_gen)\n",
- "\n",
- "for k2 in [0., 1., 5., 10., 50., 100., 500., 1000.]:\n",
- " queries_gen = get_queries_gen(QUERIES_FILE)\n",
- " query_tokens_gen = get_item_tokens_gen(queries_gen, \"w\")\n",
- "\n",
- " rsv = RSVFullRankedList(k1=1.2, k2=k2, b=0.75)\n",
- " with open(PREDICTION_FILE, \"w\") as fout:\n",
- " for query_id, query in tqdm(enumerate(query_tokens_gen)):\n",
- " ranked_list = rsv(query, inv_index)\n",
- " for doc_id, _ in ranked_list:\n",
- " fout.write(\"{} {}\\n\".format(query_id + 1, doc_id + 1))\n",
- " \n",
- " with subprocess.Popen(\"cd data && python3 eval.py\", shell=True, stdout=subprocess.PIPE) as p:\n",
- " print(\"k2={} \\n {}\".format(k2, p.stdout.read().decode()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "anaconda-cloud": {},
- "kernelspec": {
- "display_name": "Python [default]",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment