# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
#         Olivier Grisel <olivier.grisel@ensta.org>
#         Mathieu Blondel <mathieu@mblondel.org>
#         Lars Buitinck <L.J.Buitinck@uva.nl>
# License: BSD 3 clause

"""Benchmark classifiers on the 20 newsgroups text classification task.

The script vectorizes newsgroup posts with a sparse TF-IDF or hashing
vectorizer, optionally selects features with a chi-squared test, trains a
collection of classifiers that handle sparse matrices efficiently, and
reports accuracy, timing, and a comparison bar chart.
"""

from __future__ import print_function

import logging
import numpy as np
from optparse import OptionParser
import sys
from time import time
import matplotlib.pyplot as plt

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.feature_selection import SelectFromModel
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.linear_model import RidgeClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.naive_bayes import BernoulliNB, MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestCentroid
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.extmath import density
from sklearn import metrics


# Display progress logs on stdout
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s')


# parse command-line arguments
op = OptionParser()
op.add_option("--report",
              action="store_true", dest="print_report",
              help="Print a detailed classification report.")
op.add_option("--chi2_select",
              action="store", type="int", dest="select_chi2",
              help="Select some number of features using a chi-squared test")
op.add_option("--confusion_matrix",
              action="store_true", dest="print_cm",
              help="Print the confusion matrix.")
op.add_option("--top10",
              action="store_true", dest="print_top10",
              help="Print ten most discriminative terms per class"
                   " for every classifier.")
op.add_option("--all_categories",
              action="store_true", dest="all_categories",
              help="Whether to use all categories or not.")
op.add_option("--use_hashing",
              action="store_true",
              help="Use a hashing vectorizer.")
op.add_option("--n_features",
              action="store", type=int, default=2 ** 16,
              help="n_features when using the hashing vectorizer.")
op.add_option("--filtered",
              action="store_true",
              help="Remove newsgroup information that is easily overfit: "
                   "headers, signatures, and quoting.")

(opts, args) = op.parse_args()
if len(args) > 0:
    op.error("this script takes no arguments.")
    sys.exit(1)

print(__doc__)
op.print_help()
print()


###############################################################################
# Load some categories from the training set
if opts.all_categories:
    categories = None
else:
    categories = [
        'alt.atheism',
        'talk.religion.misc',
        'comp.graphics',
        'sci.space',
    ]
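# The default subset pairs two heavily overlapping groups (alt.atheism vs.
# talk.religion.misc) with two well-separated ones (comp.graphics,
# sci.space), so it exercises both easy and hard cases.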

if opts.filtered:
    remove = ('headers', 'footers', 'quotes')
else:
    remove = ()
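# Stripping headers, signature blocks and quoted replies keeps the
# classifiers from keying on newsgroup metadata instead of the actual text.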

print("Loading 20 newsgroups dataset for categories:")
print(categories if categories else "all")

data_train = fetch_20newsgroups(subset='train', categories=categories,
                                shuffle=True, random_state=42,
                                remove=remove)

data_test = fetch_20newsgroups(subset='test', categories=categories,
                               shuffle=True, random_state=42,
                               remove=remove)
print('data loaded')

categories = data_train.target_names    # actual label names (also covers categories=None)


def size_mb(docs):
    return sum(len(s.encode('utf-8')) for s in docs) / 1e6

data_train_size_mb = size_mb(data_train.data)
data_test_size_mb = size_mb(data_test.data)

print("%d documents - %0.3fMB (training set)" % (
    len(data_train.data), data_train_size_mb))
print("%d documents - %0.3fMB (test set)" % (
    len(data_test.data), data_test_size_mb))
print("%d categories" % len(categories))
print()

# the dataset provides a predefined train/test split; grab the target vectors
y_train, y_test = data_train.target, data_test.target

print("Extracting features from the training data using a sparse vectorizer")
t0 = time()
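# Two interchangeable vectorizers: HashingVectorizer maps tokens to column
# indices with a hash function (stateless, fixed memory, but the mapping back
# to token strings is lost), while TfidfVectorizer builds an in-memory
# vocabulary and reweights raw counts by inverse document frequency.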
if opts.use_hashing:
    # `alternate_sign=False` keeps all feature values non-negative (the chi2
    # selector below requires this); it replaces the `non_negative=True`
    # flag that recent scikit-learn releases removed.
    vectorizer = HashingVectorizer(stop_words='english', alternate_sign=False,
                                   n_features=opts.n_features)
    X_train = vectorizer.transform(data_train.data)
else:
    vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,
                                 stop_words='english')
    X_train = vectorizer.fit_transform(data_train.data)
duration = time() - t0
print("done in %fs at %0.3fMB/s" % (duration, data_train_size_mb / duration))
print("n_samples: %d, n_features: %d" % X_train.shape)
print()

print("Extracting features from the test data using the same vectorizer")
t0 = time()
X_test = vectorizer.transform(data_test.data)
duration = time() - t0
print("done in %fs at %0.3fMB/s" % (duration, data_test_size_mb / duration))
print("n_samples: %d, n_features: %d" % X_test.shape)
print()

# mapping from integer feature name to original token string
if opts.use_hashing:
    feature_names = None
else:
    # `get_feature_names_out()` replaces `get_feature_names()`, which was
    # removed in scikit-learn 1.2; it returns an array of token strings.
    feature_names = vectorizer.get_feature_names_out()

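# chi2 scores each feature's dependence on the class labels from the training
# counts; SelectKBest keeps only the k highest-scoring columns, which both
# speeds up training and can act as a noise filter.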
if opts.select_chi2:
    print("Extracting %d best features by a chi-squared test" %
          opts.select_chi2)
    t0 = time()
    ch2 = SelectKBest(chi2, k=opts.select_chi2)
    X_train = ch2.fit_transform(X_train, y_train)
    X_test = ch2.transform(X_test)
    # explicit None check: feature_names may be a NumPy array, whose truth
    # value is ambiguous
    if feature_names is not None:
        # keep selected feature names
        feature_names = [feature_names[i] for i
                         in ch2.get_support(indices=True)]
    print("done in %fs" % (time() - t0))
    print()

if feature_names is not None:
    feature_names = np.asarray(feature_names)


def trim(s):
    """Trim string to fit on terminal (assuming 80-column display)"""
    return s if len(s) <= 80 else s[:77] + "..."


###############################################################################
# Benchmark classifiers
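# benchmark() fits one classifier, times training and prediction, and reports
# test-set accuracy plus, for linear models, the weight-matrix density
# (fraction of non-zero coefficients) and the highest-weighted terms.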
def benchmark(clf):
    print('_' * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    pred = clf.predict(X_test)
    test_time = time() - t0
    print("test time:  %0.3fs" % test_time)

    score = metrics.accuracy_score(y_test, pred)
    print("accuracy:   %0.3f" % score)

    if hasattr(clf, 'coef_'):
        print("dimensionality: %d" % clf.coef_.shape[1])
        print("density: %f" % density(clf.coef_))

        if opts.print_top10 and feature_names is not None:
            print("top 10 keywords per class:")
            for i, category in enumerate(categories):
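                # argsort sorts ascending, so the last ten indices are the
                # ten highest-weighted terms for this class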
                top10 = np.argsort(clf.coef_[i])[-10:]
                print(trim("%s: %s"
                      % (category, " ".join(feature_names[top10]))))
        print()

    if opts.print_report:
        print("classification report:")
        print(metrics.classification_report(y_test, pred,
                                            target_names=categories))

    if opts.print_cm:
        print("confusion matrix:")
        print(metrics.confusion_matrix(y_test, pred))

    print()
    clf_descr = str(clf).split('(')[0]
    return clf_descr, score, train_time, test_time


results = []
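# All models are trained on the same feature matrix, so the collected
# (accuracy, train time, test time) triples are directly comparable.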
for clf, name in (
        (RidgeClassifier(tol=1e-2, solver="lsqr"), "Ridge Classifier"),
        # `max_iter` replaces the `n_iter` parameter removed from recent
        # scikit-learn releases.
        (Perceptron(max_iter=50), "Perceptron"),
        (PassiveAggressiveClassifier(max_iter=50), "Passive-Aggressive"),
        (KNeighborsClassifier(n_neighbors=10), "kNN"),
        (RandomForestClassifier(n_estimators=100), "Random forest")):
    print('=' * 80)
    print(name)
    results.append(benchmark(clf))

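# An L1 penalty drives many weights exactly to zero (sparse models), while
# L2 only shrinks them; the elasticnet penalty further below interpolates
# between the two.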
for penalty in ["l2", "l1"]:
    print('=' * 80)
    print("%s penalty" % penalty.upper())
    # Train Liblinear model ('squared_hinge' is the current spelling of the
    # loss that older scikit-learn releases called 'l2')
    results.append(benchmark(LinearSVC(loss='squared_hinge', penalty=penalty,
                                       dual=False, tol=1e-3)))

    # Train SGD model
    results.append(benchmark(SGDClassifier(alpha=.0001, max_iter=50,
                                           penalty=penalty)))

# Train SGD with Elastic Net penalty
print('=' * 80)
print("Elastic-Net penalty")
results.append(benchmark(SGDClassifier(alpha=.0001, max_iter=50,
                                       penalty="elasticnet")))

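# NearestCentroid assigns each document to the class whose mean feature
# vector is closest; with tf-idf features this is the classic Rocchio rule.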
# Train NearestCentroid without threshold
print('=' * 80)
print("NearestCentroid (aka Rocchio classifier)")
results.append(benchmark(NearestCentroid()))

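# MultinomialNB models per-class term counts; BernoulliNB models binary term
# presence/absence. The small alpha applies only light additive smoothing.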
# Train sparse Naive Bayes classifiers
print('=' * 80)
print("Naive Bayes")
results.append(benchmark(MultinomialNB(alpha=.01)))
results.append(benchmark(BernoulliNB(alpha=.01)))

print('=' * 80)
print("LinearSVC with L1-based feature selection")
# The smaller C, the stronger the regularization.
# The more regularization, the more sparsity.
# Recent scikit-learn removed LinearSVC's `transform` method, so the L1
# model is wrapped in SelectFromModel, which keeps the features whose
# learned weights are non-zero.
results.append(benchmark(Pipeline([
    ('feature_selection', SelectFromModel(
        LinearSVC(penalty="l1", dual=False, tol=1e-3))),
    ('classification', LinearSVC())
])))

# make some plots

indices = np.arange(len(results))

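# transpose the list of (name, score, train_time, test_time) tuples into
# four parallel lists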
results = [[x[i] for x in results] for i in range(4)]

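# training and test times are scaled by their maxima so they share the
# same [0, 1] axis as the accuracy scores in the bar chart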
clf_names, score, training_time, test_time = results
training_time = np.array(training_time) / np.max(training_time)
test_time = np.array(test_time) / np.max(test_time)

plt.figure(figsize=(12, 8))
plt.title("Score")
plt.barh(indices, score, .2, label="score", color='r')
plt.barh(indices + .3, training_time, .2, label="training time", color='g')
plt.barh(indices + .6, test_time, .2, label="test time", color='b')
plt.yticks(())
plt.legend(loc='best')
plt.subplots_adjust(left=.25, top=.95, bottom=.05)

for i, c in zip(indices, clf_names):
    plt.text(-.3, i, c)

plt.show()
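
# Example usage (assuming the script is saved as benchmark_20news.py; the
# flags are the ones defined by the OptionParser above):
#   python benchmark_20news.py --report --confusion_matrix
#   python benchmark_20news.py --use_hashing --chi2_select=1000 --top10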