Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Building Machine Learning Classifiers: Model selection"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Read in & clean text"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import nltk\n",
- "import pandas as pd\n",
- "import re\n",
- "from sklearn.feature_extraction.text import TfidfVectorizer\n",
- "import string\n",
- "\n",
- "# A set gives O(1) membership tests in clean_text; the NLTK corpus returns a list.\n",
- "stopwords = set(nltk.corpus.stopwords.words('english'))\n",
- "ps = nltk.PorterStemmer()\n",
- "\n",
- "# NOTE(review): read_csv treats the first row as a header, which the next line renames anyway;\n",
- "# if the TSV has no header row, pass header=None so the first message is not lost - confirm.\n",
- "data = pd.read_csv(\"SMSSpamCollection.tsv\", sep='\\t')\n",
- "data.columns = ['label', 'body_text']\n",
- "\n",
- "def count_punct(text):\n",
- "    \"\"\"Return the percentage (0-100, one decimal) of non-space characters in `text` that are punctuation.\n",
- "\n",
- "    Text with no non-space characters returns 0.0 instead of raising ZeroDivisionError.\n",
- "    \"\"\"\n",
- "    n_chars = len(text) - text.count(\" \")\n",
- "    if n_chars == 0:\n",
- "        return 0.0\n",
- "    count = sum(1 for char in text if char in string.punctuation)\n",
- "    # Scale before rounding: round(x, 3) * 100 can reintroduce float noise (e.g. 0.07 * 100 == 7.000000000000001).\n",
- "    return round(100 * count / n_chars, 1)\n",
- "\n",
- "data['body_len'] = data['body_text'].apply(lambda x: len(x) - x.count(\" \"))\n",
- "data['punct%'] = data['body_text'].apply(count_punct)\n",
- "\n",
- "def clean_text(text):\n",
- "    \"\"\"Analyzer for TfidfVectorizer(analyzer=...): return the list of stemmed tokens in `text`.\n",
- "\n",
- "    Removes punctuation characters, lowercases, splits on runs of non-word characters,\n",
- "    drops stopwords and empty tokens, then applies the Porter stemmer.\n",
- "    \"\"\"\n",
- "    text = \"\".join(char.lower() for char in text if char not in string.punctuation)\n",
- "    tokens = re.split(r'\\W+', text)\n",
- "    # re.split yields '' when the text starts or ends with a separator; without the\n",
- "    # `token` check the empty string would become a TF-IDF feature.\n",
- "    return [ps.stem(token) for token in tokens if token and token not in stopwords]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Split into train/test"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from sklearn.model_selection import train_test_split\n",
- "\n",
- "# Fixed seed so Restart & Run All reproduces the same split; stratify keeps the\n",
- "# label proportions identical in train and test (the classes are imbalanced).\n",
- "X_train, X_test, y_train, y_test = train_test_split(\n",
- "    data[['body_text', 'body_len', 'punct%']], data['label'],\n",
- "    test_size=0.2, random_state=42, stratify=data['label'])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Vectorize text"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<style scoped>\n",
- " .dataframe tbody tr th:only-of-type {\n",
- " vertical-align: middle;\n",
- " }\n",
- "\n",
- " .dataframe tbody tr th {\n",
- " vertical-align: top;\n",
- " }\n",
- "\n",
- " .dataframe thead th {\n",
- " text-align: right;\n",
- " }\n",
- "</style>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\">\n",
- " <th></th>\n",
- " <th>body_len</th>\n",
- " <th>punct%</th>\n",
- " <th>0</th>\n",
- " <th>1</th>\n",
- " <th>2</th>\n",
- " <th>3</th>\n",
- " <th>4</th>\n",
- " <th>5</th>\n",
- " <th>6</th>\n",
- " <th>7</th>\n",
- " <th>...</th>\n",
- " <th>7112</th>\n",
- " <th>7113</th>\n",
- " <th>7114</th>\n",
- " <th>7115</th>\n",
- " <th>7116</th>\n",
- " <th>7117</th>\n",
- " <th>7118</th>\n",
- " <th>7119</th>\n",
- " <th>7120</th>\n",
- " <th>7121</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <th>0</th>\n",
- " <td>328</td>\n",
- " <td>8.5</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>...</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>1</th>\n",
- " <td>25</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>...</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>2</th>\n",
- " <td>112</td>\n",
- " <td>2.7</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>...</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>3</th>\n",
- " <td>24</td>\n",
- " <td>8.3</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>...</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.645353</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>4</th>\n",
- " <td>94</td>\n",
- " <td>3.2</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>...</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.0</td>\n",
- " <td>0.0</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table>\n",
- "<p>5 rows × 7124 columns</p>\n",
- "</div>"
- ],
- "text/plain": [
- " body_len punct% 0 1 2 3 4 5 6 7 ... 7112 7113 \\\n",
- "0 328 8.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
- "1 25 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
- "2 112 2.7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
- "3 24 8.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
- "4 94 3.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 \n",
- "\n",
- " 7114 7115 7116 7117 7118 7119 7120 7121 \n",
- "0 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.0 \n",
- "1 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.0 \n",
- "2 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.0 \n",
- "3 0.0 0.0 0.0 0.0 0.0 0.645353 0.0 0.0 \n",
- "4 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.0 \n",
- "\n",
- "[5 rows x 7124 columns]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tfidf_vect = TfidfVectorizer(analyzer=clean_text)\n",
- "# Fit the vocabulary/IDF on the training texts only so no test information leaks in.\n",
- "tfidf_vect_fit = tfidf_vect.fit(X_train['body_text'])\n",
- "\n",
- "tfidf_train = tfidf_vect_fit.transform(X_train['body_text'])\n",
- "tfidf_test = tfidf_vect_fit.transform(X_test['body_text'])\n",
- "\n",
- "def build_features(X, tfidf_matrix):\n",
- "    \"\"\"Concatenate the engineered columns of `X` (body_len, punct%) with the dense TF-IDF matrix.\n",
- "\n",
- "    The index is reset so rows align positionally with the TF-IDF rows. Column names are\n",
- "    cast to str because scikit-learn >= 1.2 refuses to fit on mixed int/str column names.\n",
- "    \"\"\"\n",
- "    features = pd.concat([X[['body_len', 'punct%']].reset_index(drop=True),\n",
- "                          pd.DataFrame(tfidf_matrix.toarray())], axis=1)\n",
- "    features.columns = features.columns.astype(str)\n",
- "    return features\n",
- "\n",
- "X_train_vect = build_features(X_train, tfidf_train)\n",
- "X_test_vect = build_features(X_test, tfidf_test)\n",
- "\n",
- "X_train_vect.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Final evaluation of models"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
- "from sklearn.metrics import precision_recall_fscore_support as score\n",
- "import time"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Fit Time: 1.924, Predict Time: 0.152, Precision: 1.0 / Recall: 0.839 / Accuracy: 0.978\n"
- ]
- }
- ],
- "source": [
- "# Random forest: time a single fit and a single predict, then score on the held-out test set.\n",
- "rf = RandomForestClassifier(n_estimators=150, max_depth=None, n_jobs=-1, random_state=42)\n",
- "\n",
- "# perf_counter is a monotonic, high-resolution clock intended for measuring durations.\n",
- "start = time.perf_counter()\n",
- "rf_model = rf.fit(X_train_vect, y_train)\n",
- "fit_time = time.perf_counter() - start\n",
- "\n",
- "start = time.perf_counter()\n",
- "y_pred = rf_model.predict(X_test_vect)\n",
- "pred_time = time.perf_counter() - start\n",
- "\n",
- "precision, recall, fscore, support = score(y_test, y_pred, pos_label='spam', average='binary')\n",
- "accuracy = (y_pred == y_test).sum() / len(y_pred)\n",
- "print('Fit Time: {}, Predict Time: {}, Precision: {} / Recall: {} / Accuracy: {}'.format(\n",
- "    round(fit_time, 3), round(pred_time, 3), round(precision, 3), round(recall, 3), round(accuracy, 3)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Fit Time: 218.932, Predict Time: 0.142, Precision: 0.915 / Recall: 0.866 / Accuracy: 0.971\n"
- ]
- }
- ],
- "source": [
- "# Gradient boosting, timed and scored the same way as the random forest.\n",
- "# NOTE: the recorded run below took ~219 s to fit.\n",
- "gb = GradientBoostingClassifier(n_estimators=150, max_depth=11, random_state=42)\n",
- "\n",
- "start = time.perf_counter()\n",
- "gb_model = gb.fit(X_train_vect, y_train)\n",
- "fit_time = time.perf_counter() - start\n",
- "\n",
- "start = time.perf_counter()\n",
- "y_pred = gb_model.predict(X_test_vect)\n",
- "pred_time = time.perf_counter() - start\n",
- "\n",
- "precision, recall, fscore, support = score(y_test, y_pred, pos_label='spam', average='binary')\n",
- "accuracy = (y_pred == y_test).sum() / len(y_pred)\n",
- "print('Fit Time: {}, Predict Time: {}, Precision: {} / Recall: {} / Accuracy: {}'.format(\n",
- "    round(fit_time, 3), round(pred_time, 3), round(precision, 3), round(recall, 3), round(accuracy, 3)))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement