Guest User

Untitled

a guest
Feb 20th, 2017
63
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 1.80 KB | None | 0 0
  1. import pandas as pd
  2. from ggplot import *
  3. from sklearn.datasets import fetch_20newsgroups
  4. from sklearn.metrics import roc_curve
  5.  
  6. # vectorizer
  7. from sklearn.feature_extraction.text import HashingVectorizer
  8.  
  9. # our classifiers
  10. from sklearn.naive_bayes import BernoulliNB, MultinomialNB
  11. from sklearn.svm import SVC
  12. from sklearn.neighbors import KNeighborsClassifier
  13. from sklearn.ensemble import RandomForestClassifier
  14.  
  15. categories = [
  16. 'alt.atheism',
  17. 'talk.religion.misc',
  18. 'comp.graphics',
  19. 'sci.space'
  20. ]
  21.  
  22. data_train = fetch_20newsgroups(subset='train', categories=categories,
  23. shuffle=True, random_state=42)
  24.  
  25. data_test = fetch_20newsgroups(subset='test', categories=categories,
  26. shuffle=True, random_state=42)
  27.  
  28. categories = data_train.target_names
  29.  
  30. vectorizer = HashingVectorizer(stop_words='english', non_negative=True, n_features=1000)
  31. X_train = vectorizer.fit_transform(data_train.data)
  32. X_test = vectorizer.transform(data_test.data)
  33.  
  34. y_train = data_train.target==0
  35. y_test = data_test.target==0
  36.  
  37.  
  38.  
  39. clfs = [
  40. ("MultinomialNB", MultinomialNB()),
  41. ("BernoulliNB", BernoulliNB()),
  42. ("KNeighborsClassifier", KNeighborsClassifier()),
  43. ("RandomForestClassifier", RandomForestClassifier()),
  44. ("SVM", SVC(probability=True))
  45. ]
  46.  
  47. all_results = None
  48. for name, clf in clfs:
  49. clf.fit(X_train.todense(), y_train)
  50. probs = clf.predict_proba(X_test.todense())[:,1]
  51. fpr, tpr, thresh = roc_curve(y_test, probs)
  52. results = pd.DataFrame({
  53. "name": name,
  54. "fpr": fpr,
  55. "tpr": tpr
  56. })
  57. if all_results is None:
  58. all_results = results
  59. else:
  60. all_results = all_results.append(results)
  61.  
  62. ggplot(aes(x='fpr', y='tpr', color='name'), data=all_results) + \
  63. geom_step() + \
  64. geom_abline(color="black") + \
  65. ggtitle("Text Classification Benchmark on 20 News Groups")
Add Comment
Please sign in to add a comment.