SHARE
TWEET

Untitled

a guest Jul 22nd, 2019 63 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
import javafx.util.Pair;
import org.nd4j.linalg.io.ClassPathResource;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;
  9.  
  10. public class PredictorAgent {
  11.     public static void main(String[] args) {
  12.         PredictorAgent pa = new PredictorAgent();
  13.         pa.train();
  14.         long startTime = System.currentTimeMillis();
  15.         System.out.println(pa.getSuggestions("процент исполнения", 8));
  16.         System.out.println("Прошло "
  17.                 + (new Double(System.currentTimeMillis() - startTime) / 1000) + " секунд.");
  18.     }
  19.  
  20.     private String text = "";
  21.     private static final String QUESTIONS_FILE_NAME = "questions.txt";
  22.     private static final String PAD_TOKEN = " # ";
  23.     private static final int MAX_ITERATIONS = 100;
  24.     private HashMap<Pair<String, String>, HashMap<String, Double>> condFreq = new HashMap<>();
  25.     private HashMap<Pair<String, String>, RandomCollection<String>> suggGenerator = new HashMap<>();
  26.  
  27.     public ArrayList<String> getSuggestions(String request, int numRequested) {
  28.         request = request.toLowerCase();
  29.         if (condFreq.isEmpty()) {
  30.             this.train();
  31.         }
  32.  
  33.         int numRequestTokens = request.split(" ").length;
  34.         if (numRequestTokens < 2) return null;
  35.  
  36.         String first = request.split(" ")[numRequestTokens - 2];
  37.         String second = request.split(" ")[numRequestTokens - 1];
  38.         Pair<String, String> keyPair = new Pair<>(first, second);
  39.         HashMap<String, Double> reqFreq = condFreq.get(keyPair);
  40.  
  41.         HashSet<String> uniqueResults = new HashSet<>();
  42.         ArrayList<String> result = new ArrayList<>();
  43.  
  44.         if (!condFreq.containsKey(keyPair)) return result;
  45.  
  46.         int iterationsNum = 0;
  47.  
  48.         while (uniqueResults.size() < Math.min(numRequested, condFreq.get(keyPair).size())) {
  49.             if (iterationsNum++ > MAX_ITERATIONS) break;
  50.             if (suggGenerator.keySet().contains(keyPair)) {
  51.  
  52.                 String suggestion = suggGenerator.get(keyPair).next();
  53.                 if (!uniqueResults.contains(suggestion)) result.add(suggestion);
  54.  
  55.                 uniqueResults.add(suggestion);
  56.             }
  57.         }
  58.         return result;
  59.     }
  60.  
  61.     public void train() {
  62.         // train n-grams condFreq
  63.  
  64.         // System.out.println("В файле " + QUESTIONS_FILE_NAME + " " + this.text.split(" ").length + " токенов.");
  65.         String tokens[] = this.text.split(" ");
  66.         String first = tokens[0];
  67.         String second = tokens[1];
  68.  
  69.         for (Integer i = 2; i < tokens.length; i++) {
  70.             Pair<String, String> fs = new Pair<>(first, second);
  71.             if (condFreq.containsKey(fs)) {
  72.                 if (condFreq.get(fs).containsKey(tokens[i])) {
  73.                     Double curVal = condFreq.get(fs).get(tokens[i]);
  74.                     condFreq.get(fs).put(tokens[i], curVal + 1.);
  75.                 } else {
  76.                     condFreq.get(fs).putIfAbsent(tokens[i], 1.);
  77.                 }
  78.             } else {
  79.  
  80.                 condFreq.putIfAbsent(fs, new HashMap<>());
  81.                 condFreq.get(fs).put(tokens[i], 1.);
  82.             }
  83.             first = second;
  84.             second = tokens[i];
  85.         }
  86.  
  87.         for (Pair mainKey : condFreq.keySet()) {
  88.             for (String key : condFreq.get(mainKey).keySet()) {
  89.                 if (!suggGenerator.keySet().contains(mainKey)) {
  90.                     suggGenerator.put(mainKey, new RandomCollection());
  91.                 }
  92.                 suggGenerator.get(mainKey).add(condFreq.get(mainKey).get(key), key);
  93.             }
  94.         }
  95.     }
  96.  
  97.  
  98.     public PredictorAgent() {
  99.         try {
  100.             FileInputStream fis = new FileInputStream(
  101.                     new ClassPathResource(QUESTIONS_FILE_NAME).getFile().getPath());
  102.             BufferedReader br = new BufferedReader(new InputStreamReader(fis));
  103.             String question = "";
  104.             while ((question = br.readLine()) != null) {
  105.                 this.text += (question.toLowerCase() + PAD_TOKEN);
  106.             }
  107.         } catch (IOException e) {
  108.             System.out.println(e.getMessage());
  109.         }
  110.     }
  111.  
  112.     public String getText() {
  113.         return text;
  114.     }
  115.  
  116.     private class RandomCollection<E> {
  117.         private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  118.         private final Random random;
  119.         private double total = 0;
  120.  
  121.         public RandomCollection() {
  122.             this(new Random());
  123.         }
  124.  
  125.         public RandomCollection(Random random) {
  126.             this.random = random;
  127.         }
  128.  
  129.         public RandomCollection<E> add(double weight, E result) {
  130.             if (weight <= 0) return this;
  131.             total += weight;
  132.             map.put(total, result);
  133.             return this;
  134.         }
  135.  
  136.         public E next() {
  137.             double value = random.nextDouble() * total;
  138.             return map.higherEntry(value).getValue();
  139.         }
  140.     }
  141. }
  142.  
  143. //                                      Old but Gold
  144. //        for (Object mainKey : condFreq.keySet()) {
  145. //            int sum = 0;
  146. //            HashMap<String, Double> tempCondFreq = condFreq.get(mainKey);
  147. //            for (String key : tempCondFreq.keySet()) {
  148. //                sum += tempCondFreq.get(key);
  149. //            }
  150. //
  151. //            // transforming frequencies to probabilities
  152. //            for (String key : tempCondFreq.keySet()) {
  153. //                tempCondFreq.put(key, tempCondFreq.get(key) / sum);
  154. //            }
  155. //        }
  156. //        System.out.println("Test query (prob): " + condFreq.get(new Pair<String, String>("расходы", "на")));
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top