Advertisement
Track33r

Naive Bayes implementation for assigment

Jul 29th, 2012
262
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 16.60 KB | None | 0 0
  1. // NLP Programming Assignment #3
  2. // NaiveBayes
  3. // 2012
  4.  
  5. //
  6. // Things for you to implement are marked with TODO!
  7. // Generally, you should not need to touch things *not* marked TODO
  8. //
  9. // Remember that when you submit your code, it is not run from the command line
  10. // and your main() will *not* be run. To be safest, restrict your changes to
  11. // addExample() and classify() and anything you further invoke from there.
  12. //
  13.  
  14. import java.util.*;
  15. import java.util.regex.*;
  16. import java.io.*;
  17. import sun.misc.Regexp;
  18.  
  19. public class NaiveBayes {
  20.  
  21.   public static boolean FILTER_STOP_WORDS = false; // this gets set in main()
  22.   private static List<String> stopList = readFile(new File("../data/english.stop"));
  23.  
  24.   class MyIndex
  25.   {
  26.       HashMap<String, Integer> _wordsCounts = new HashMap<String, Integer>();
  27.       int _totalWords = 0;
  28.      
  29.       public void Clear()
  30.       {
  31.          _wordsCounts.clear();
  32.          _totalWords = 0;
  33.       }
  34.      
  35.       public int Vocabulary()
  36.       {
  37.           return _wordsCounts.size();
  38.       }
  39.      
  40.       public int Total()
  41.       {
  42.           return _totalWords;
  43.       }
  44.      
  45.       public int GetWordCount(String word)
  46.       {
  47.           Integer c = _wordsCounts.get(word);
  48.           c = c==null?0:c;
  49.           return  c;
  50.       }      
  51.      
  52.       public Set<String> Words()
  53.       {
  54.           return _wordsCounts.keySet();
  55.       }
  56.      
  57.       void Merge(MyIndex other)
  58.       {
  59.           for(Map.Entry<String, Integer> i: other._wordsCounts.entrySet())
  60.           {
  61.               int c = GetWordCount(i.getKey());
  62.                
  63.               _wordsCounts.put(i.getKey(), c + i.getValue());
  64.           }
  65.          
  66.           _totalWords += other._totalWords;
  67.       }    
  68.       Pattern _filter = Pattern.compile("[\\w]*['-]?[\\w]+");
  69.       public void Add(Collection<String> words)
  70.       {
  71.          
  72.           for(String i: words)
  73.           {
  74.               /*if(!_filter.matcher(i).matches())
  75.               {
  76.                   continue;
  77.               }*/
  78.               _totalWords++;
  79.               int c = GetWordCount(i);
  80.               _wordsCounts.put(i, c + 1);
  81.           }
  82.       }    
  83.      
  84.   }
  85.  
  86.   class MyBayes
  87.   {
  88.       MyIndex _vocabulary;
  89.       MyIndex _tempVocabulary = new MyIndex();;
  90.       MyIndex _self = new MyIndex();
  91.       public int _totalCount = 0;
  92.       public MyBayes(MyIndex vocabulary)
  93.       {
  94.           _vocabulary=vocabulary;
  95.       }
  96.      
  97.        void Add(List<String> words, boolean binary)
  98.        {
  99.            _totalCount+=words.size();
  100.            if(binary)
  101.            {
  102.                _tempVocabulary.Clear();
  103.                _tempVocabulary.Add(words);
  104.                _vocabulary.Merge(_tempVocabulary);
  105.            
  106.                _self.Add(_tempVocabulary.Words());
  107.            }
  108.            else
  109.            {
  110.                _vocabulary.Add(words);
  111.                _self.Add(words);
  112.            }
  113.            
  114.        }
  115.      
  116.        double score(Collection<String> words)
  117.       {
  118.           double scores = 0;
  119.          
  120.           double v = _vocabulary.Vocabulary();
  121.           double t = _self.Total();
  122.          
  123.           for(String word: words)
  124.           {
  125.              
  126.              double score = 0;
  127.              
  128.              double c = _self.GetWordCount(word);
  129.              
  130.              
  131.              score = (c + 1)/(t + v);
  132.              
  133.              scores += Math.log(score);
  134.              
  135.           }
  136.          
  137.           return  scores;
  138.       }
  139.      
  140.   }
  141.  
  142.   //TODO
  143.   /**
  144.    * Put your code for adding information to your NB classifier here
  145.    **/
  146.   MyIndex _index = new MyIndex();
  147.   MyBayes _neg = new MyBayes(_index);
  148.   MyBayes _pos = new MyBayes(_index);
  149.   double _pp = 0;
  150.  
  151.   public List<String> Preprocess(List<String> words)
  152.   {
  153.       String prefix = "";
  154.       List<String> out = new ArrayList<String>(words.size());
  155.       for(String s: words)
  156.       {
  157.           if(".".equals(s) || ",".equals(s))
  158.           {
  159.               prefix = "";
  160.           }
  161.          
  162.           out.add((prefix + s).toLowerCase());
  163.          
  164.           if("not".equals(s)
  165.               ||"didn't".equals(s)
  166.               ||"don't".equals(s))              
  167.           {
  168.               prefix = "NOT_";
  169.           }        
  170.          
  171.       }
  172.       return out;
  173.   }
  174.  
  175.   public  void Test()
  176.   {
  177.       ArrayList<String> arr = new ArrayList<String>();
  178.       //comedy
  179.       //fun, couple, love, love
  180.       arr.add("fun");
  181.       arr.add("couple");
  182.       arr.add("love");
  183.       arr.add("love");
  184.       _pos.Add(arr, false);
  185.       arr.clear();
  186.      
  187.       //couple, fly, fast, fun, fun
  188.       arr.add("couple");
  189.       arr.add("fly");
  190.       arr.add("fast");
  191.       arr.add("fun");
  192.       arr.add("fun");
  193.       _pos.Add(arr, false);
  194.      
  195.       //action
  196.       //fast, furious, shoot
  197.       arr.add("fast");
  198.       arr.add("furious");
  199.       arr.add("shoot");
  200.       _neg.Add(arr, false);
  201.       arr.clear();
  202.       //furious, shoot, shoot, fun
  203.       arr.add("furious");
  204.       arr.add("shoot");
  205.       arr.add("shoot");
  206.       arr.add("fun");
  207.       _neg.Add(arr, false);
  208.       arr.clear();
  209.       //fly, fast, shoot, love
  210.       arr.add("fly");
  211.       arr.add("fast");
  212.       arr.add("shoot");
  213.       arr.add("love");
  214.       _neg.Add(arr, false);
  215.       arr.clear();
  216.      
  217.       //fast, couple, shoot, fly
  218.       arr.add("fun");
  219.       arr.add("couple");
  220.       arr.add("shoot");
  221.       arr.add("fast");
  222.      
  223.       double s1 = Math.log(2.0/5.0) + _pos.score(arr);
  224.       double s2 = Math.log(3.0/5.0) + _neg.score(arr);
  225.       if(s1 > s2)
  226.       {
  227.           System.out.print("comedy");
  228.        
  229.       }
  230.       else
  231.       {
  232.           System.out.print("action");
  233.       }            
  234.   }
  235.  
  236.   public  void Test2()
  237.   {
  238.       boolean binary = true;
  239.       ArrayList<String> arr = new ArrayList<String>();
  240.       //positive
  241.       //"good"  "poor"  "great"
  242.       //3   0   3
  243.       arr.add("good");
  244.       arr.add("good");
  245.       arr.add("good");
  246.       arr.add("great");
  247.       arr.add("great");
  248.       arr.add("great");
  249.       _pos.Add(arr, binary);
  250.       arr.clear();
  251.       //"good"  "poor"  "great"
  252.       //0   1   2      
  253.       arr.add("poor");
  254.       arr.add("great");
  255.       arr.add("great");
  256.       _pos.Add(arr, binary);
  257.       arr.clear();
  258.      
  259.       //negative
  260.       //"good"  "poor"  "great"
  261.       //1   3   0    
  262.       arr.add("good");
  263.       arr.add("poor");
  264.       arr.add("poor");
  265.       arr.add("poor");
  266.       _neg.Add(arr, binary);
  267.       arr.clear();
  268.      
  269.       //negative
  270.       //"good"  "poor"  "great"
  271.       //1   5   2    
  272.       arr.add("good");
  273.       arr.add("poor");
  274.       arr.add("poor");
  275.       arr.add("poor");
  276.       arr.add("poor");
  277.       arr.add("poor");
  278.       arr.add("great");
  279.       arr.add("great");
  280.       _neg.Add(arr, binary);
  281.       arr.clear();
  282.      
  283.       //"good"  "poor"  "great"
  284.       //0   2   0    
  285.       arr.add("poor");
  286.       arr.add("poor");
  287.       _neg.Add(arr, binary);
  288.       arr.clear();
  289.      
  290.       //Good acting, poor plot.
  291.       arr.add("good");
  292.       arr.add("acting");
  293.       arr.add("poor");
  294.       arr.add("plot");
  295.      
  296.       double s1 = Math.log(2.0/5.0) + _pos.score(arr);
  297.       double s2 = Math.log(3.0/5.0) + _neg.score(arr);
  298.       if(s1 > s2)
  299.       {
  300.           System.out.print("pos");
  301.        
  302.       }
  303.       else
  304.       {
  305.           System.out.print("neg");
  306.       }
  307.      
  308.   }
  309.   public void addExample(String klass, List<String> words)
  310.   {
  311.       Test2();
  312.       return;
  313.      /* List<String> processed = Preprocess(words);
  314.     if( "pos".equals(klass))
  315.     {
  316.         _pos.Add(processed, true);
  317.     }
  318.     else
  319.     {
  320.         _neg.Add(processed, true);
  321.     }*/
  322.   }
  323.  
  324.   //TODO
  325.   /**
  326.    *  Put your code here for deciding the class of the input file.
  327.    *  Currently, it just randomly chooses "pos" or "negative"
  328.    */  
  329.   public String classify(List<String> words)
  330.   {
  331.     double pp = (double)_pos._totalCount / (double)(_pos._totalCount + _neg._totalCount);
  332.     double pn = (double)_neg._totalCount / (double)(_pos._totalCount + _neg._totalCount);
  333.     //if( (_pos.score(words) ) > (_neg.score(words)) )
  334.     MyIndex index = new MyIndex();
  335.     index.Add(words);
  336.     if((_pos.score(index.Words()) + Math.log(pp)) > (_neg.score(index.Words())  + Math.log(pn)))
  337.     {
  338.         return "pos";
  339.     }
  340.     else
  341.     {
  342.         return "neg";
  343.     }
  344.   }
  345.  
  346.  
  347.  
  348.  
  349.   public void train(String trainPath) {
  350.     File trainDir = new File(trainPath);
  351.     if (!trainDir.isDirectory()) {
  352.       System.err.println("[ERROR]\tinvalid training directory specified.  ");
  353.     }
  354.  
  355.     TrainSplit split = new TrainSplit();
  356.     for(File dir: trainDir.listFiles()) {
  357.     if(!dir.getName().startsWith(".")) {
  358.         List<File> dirList = Arrays.asList(dir.listFiles());
  359.         for(File f: dirList) {
  360.           split.train.add(f);
  361.         }
  362.     }
  363.     }
  364.     for(File file: split.train) {
  365.       String klass = file.getParentFile().getName();
  366.       List<String> words = readFile(file);
  367.         if (FILTER_STOP_WORDS) {words = filterStopWords(words);}
  368.         addExample(klass,words);
  369.     }
  370.     return;
  371.   }
  372.  
  373.   public List<List<String>> readTest(String ch_aux) {
  374.     List<List<String>> data = new ArrayList<List<String>>();
  375.     String [] docs = ch_aux.split("###");
  376.     TrainSplit split = new TrainSplit();
  377.     for(String doc : docs) {
  378.       List<String> words = segmentWords(doc);
  379.       if (FILTER_STOP_WORDS) {words = filterStopWords(words);}
  380.       data.add(words);
  381.     }
  382.     return data;
  383.   }
  384.  
  385.      
  386.   /**
  387.    * This class holds the list of train and test files for a given CV fold
  388.    * constructed in getFolds()
  389.    **/
  390.   public static class TrainSplit {
  391.     // training files for this split
  392.     List<File> train = new ArrayList<File>();
  393.     // test files for this split;
  394.     List<File> test = new ArrayList<File>();
  395.   }
  396.  
  397.   public static int numFolds = 10;
  398.  
  399.   /**
  400.    * This creates train/test splits for each of the numFold folds.
  401.    **/
  402.   static public List<TrainSplit> getFolds(List<File> files) {
  403.     List<TrainSplit> splits = new ArrayList<TrainSplit>();
  404.    
  405.     for( Integer fold=0; fold<numFolds; fold++ ) {
  406.       TrainSplit split = new TrainSplit();
  407.       for(File file: files) {
  408.         if( file.getName().subSequence(2,3).equals(fold.toString()) ) {
  409.           split.test.add(file);
  410.         } else {
  411.           split.train.add(file);
  412.         }
  413.       }
  414.  
  415.       splits.add(split);
  416.     }
  417.     return splits;
  418.   }
  419.  
  420.   // returns accuracy
  421.   public double evaluate(TrainSplit split) {
  422.     int numCorrect = 0;
  423.     for (File file : split.test) {
  424.       String klass = file.getParentFile().getName();
  425.         List<String> words = readFile(file);
  426.         if (FILTER_STOP_WORDS) {words = filterStopWords(words);}
  427.       String guess = classify(words);
  428.       if(klass.equals(guess)) {
  429.           numCorrect++;
  430.       }
  431.     }
  432.     return ((double)numCorrect)/split.test.size();
  433.   }
  434.  
  435.  
  436.   /**
  437.    * Remove any stop words or punctuation from a list of words.
  438.    **/
  439.   public static List<String> filterStopWords(List<String> words) {
  440.     List<String> filtered = new ArrayList<String>();
  441.     for (String word :words) {
  442.       if (!stopList.contains(word) && !word.matches(".*\\W+.*")) {
  443.     filtered.add(word);
  444.       }
  445.     }
  446.     return filtered;
  447.   }
  448.  
  449.   /**
  450.    * Code for reading a file.  you probably don't want to modify anything here,
  451.    * unless you don't like the way we segment files.
  452.    **/
  453.   private static List<String> readFile(File f) {
  454.     try {
  455.       StringBuilder contents = new StringBuilder();
  456.  
  457.       BufferedReader input = new BufferedReader(new FileReader(f));
  458.       for(String line = input.readLine(); line != null; line = input.readLine()) {
  459.         contents.append(line);
  460.         contents.append("\n");
  461.       }
  462.       input.close();
  463.  
  464.       return segmentWords(contents.toString());
  465.      
  466.     } catch(IOException e) {
  467.       e.printStackTrace();
  468.       System.exit(1);
  469.       return null;
  470.     }
  471.   }
  472.  
  473.   /**
  474.    * Splits lines on whitespace for file reading
  475.    **/
  476.   private static List<String> segmentWords(String s) {
  477.     List<String> ret = new ArrayList<String>();
  478.    
  479.     for(String word:  s.split("\\s")) {
  480.       if(word.length() > 0) {
  481.         ret.add(word);
  482.       }
  483.     }
  484.     return ret;
  485.   }
  486.  
  487.   public List<TrainSplit> getTrainSplits(String trainPath) {
  488.     File trainDir = new File(trainPath);
  489.     if (!trainDir.isDirectory()) {
  490.       System.err.println("[ERROR]\tinvalid training directory specified.  ");
  491.     }
  492.     List<TrainSplit> splits = new ArrayList<TrainSplit>();
  493.     List<File> files = new ArrayList<File>();
  494.     for(File dir: trainDir.listFiles()) {
  495.     if(!dir.getName().startsWith(".")) {
  496.         List<File> dirList = Arrays.asList(dir.listFiles());
  497.         for(File f: dirList) {
  498.           files.add(f);
  499.         }
  500.     }
  501.     }
  502.     splits = getFolds(files);
  503.     return splits;
  504.   }
  505.  
  506.  
  507.   /**
  508.    * build splits according to command line args.  If args.length==1
  509.    * do 10-fold cross validation, if args.length==2 create one TrainSplit
  510.    * with all files from the train_dir and all files from the test_dir
  511.    */
  512.   private static List<TrainSplit> buildSplits(List<String> args) {
  513.     File trainDir = new File(args.get(0));
  514.     if (!trainDir.isDirectory()) {
  515.       System.err.println("[ERROR]\tinvalid training directory specified.  ");
  516.     }
  517.  
  518.     List<TrainSplit> splits = new ArrayList<TrainSplit>();
  519.     if (args.size() == 1) {
  520.       System.out.println("[INFO]\tPerforming 10-fold cross-validation on data set:\t"+args.get(0));
  521.       List<File> files = new ArrayList<File>();
  522.       for(File dir: trainDir.listFiles()) {
  523.     if(!dir.getName().startsWith(".")) {
  524.         List<File> dirList = Arrays.asList(dir.listFiles());
  525.         for(File f: dirList) {
  526.           files.add(f);
  527.         }
  528.     }
  529.       }
  530.       splits = getFolds(files);
  531.     } else if (args.size() == 2) {
  532.       // testing/training on two different data sets is treated like a single fold
  533.       System.out.println("[INFO]\tTraining on data set:\t"+args.get(0)+" testing on data set:\t"+args.get(1));
  534.       TrainSplit split = new TrainSplit();
  535.       for(File dir: trainDir.listFiles()) {
  536.     if(!dir.getName().startsWith(".")) {
  537.         List<File> dirList = Arrays.asList(dir.listFiles());
  538.         for(File f: dirList) {
  539.           split.train.add(f);
  540.         }
  541.     }
  542.       }
  543.       File testDir = new File(args.get(1));
  544.       if (!testDir.isDirectory()) {
  545.     System.err.println("[ERROR]\tinvalid testing directory specified.  ");
  546.       }
  547.       for(File dir: testDir.listFiles()) {
  548.     if(!dir.getName().startsWith(".")) {
  549.         List<File> dirList = Arrays.asList(dir.listFiles());
  550.         for(File f: dirList) {
  551.           split.test.add(f);
  552.         }
  553.     }
  554.       }
  555.       splits.add(split);
  556.     }
  557.     return splits;
  558.   }
  559.  
  560.   public void train(TrainSplit split) {
  561.       for(File file: split.train) {
  562.         String klass = file.getParentFile().getName();
  563.         List<String> words = readFile(file);
  564.     if (FILTER_STOP_WORDS) {words = filterStopWords(words);}
  565.         addExample(klass,words);
  566.       }
  567.   }
  568.  
  569.  
  570.   public static void main(String[] args) {
  571.     List<String> otherArgs = Arrays.asList(args);
  572.     if ( args.length > 0 && args[0].equals("-f") ) {
  573.       FILTER_STOP_WORDS = true;
  574.       otherArgs = otherArgs.subList(1,otherArgs.size());
  575.     }
  576.     if (otherArgs.size() < 1 || otherArgs.size() > 2) {
  577.       System.out.println("[ERROR]\tInvalid number of arguments");
  578.       System.out.println("\tUsage: java -cp [-f] trainDir [testDir]");
  579.       System.out.println("\tWith -f flag implements stop word removal.");
  580.       System.out.println("\tIf testDir is omitted, 10-fold cross validation is used for evaluation");
  581.       return;
  582.     }
  583.     System.out.println("[INFO]\tFILTER_STOP_WORDS="+FILTER_STOP_WORDS);
  584.    
  585.     List<TrainSplit> splits = buildSplits(otherArgs);
  586.     double avgAccuracy = 0.0;
  587.     int fold = 0;
  588.     for(TrainSplit split: splits) {
  589.       NaiveBayes classifier = new NaiveBayes();
  590.       double accuracy = 0.0;
  591.  
  592.       for(File file: split.train) {
  593.         String klass = file.getParentFile().getName();
  594.         List<String> words = readFile(file);
  595.     if (FILTER_STOP_WORDS) {words = filterStopWords(words);}
  596.         classifier.addExample(klass,words);
  597.       }
  598.  
  599.       for (File file : split.test) {
  600.         String klass = file.getParentFile().getName();
  601.     List<String> words = readFile(file);
  602.         if (FILTER_STOP_WORDS) {words = filterStopWords(words);}
  603.         String guess = classifier.classify(words);
  604.         if(klass.equals(guess)) {
  605.       accuracy++;
  606.         }
  607.       }
  608.       accuracy = accuracy/split.test.size();
  609.       avgAccuracy += accuracy;
  610.       System.out.println("[INFO]\tFold " + fold + " Accuracy: " + accuracy);
  611.       fold += 1;
  612.     }
  613.     avgAccuracy = avgAccuracy / numFolds;
  614.     System.out.println("[INFO]\tAccuracy: " + avgAccuracy);
  615.   }
  616. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement