Advertisement
Guest User

Untitled

a guest
Jul 22nd, 2019
101
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.67 KB | None | 0 0
  1. import javafx.util.Pair;
  2. import org.nd4j.linalg.io.ClassPathResource;
  3.  
  4. import java.io.BufferedReader;
  5. import java.io.FileInputStream;
  6. import java.io.IOException;
  7. import java.io.InputStreamReader;
  8. import java.util.*;
  9.  
  10. public class PredictorAgent {
  11. public static void main(String[] args) {
  12. PredictorAgent pa = new PredictorAgent();
  13. pa.train();
  14. long startTime = System.currentTimeMillis();
  15. System.out.println(pa.getSuggestions("процент исполнения", 8));
  16. System.out.println("Прошло "
  17. + (new Double(System.currentTimeMillis() - startTime) / 1000) + " секунд.");
  18. }
  19.  
  20. private String text = "";
  21. private static final String QUESTIONS_FILE_NAME = "questions.txt";
  22. private static final String PAD_TOKEN = " # ";
  23. private static final int MAX_ITERATIONS = 100;
  24. private HashMap<Pair<String, String>, HashMap<String, Double>> condFreq = new HashMap<>();
  25. private HashMap<Pair<String, String>, RandomCollection<String>> suggGenerator = new HashMap<>();
  26.  
  27. public ArrayList<String> getSuggestions(String request, int numRequested) {
  28. request = request.toLowerCase();
  29. if (condFreq.isEmpty()) {
  30. this.train();
  31. }
  32.  
  33. int numRequestTokens = request.split(" ").length;
  34. if (numRequestTokens < 2) return null;
  35.  
  36. String first = request.split(" ")[numRequestTokens - 2];
  37. String second = request.split(" ")[numRequestTokens - 1];
  38. Pair<String, String> keyPair = new Pair<>(first, second);
  39. HashMap<String, Double> reqFreq = condFreq.get(keyPair);
  40.  
  41. HashSet<String> uniqueResults = new HashSet<>();
  42. ArrayList<String> result = new ArrayList<>();
  43.  
  44. if (!condFreq.containsKey(keyPair)) return result;
  45.  
  46. int iterationsNum = 0;
  47.  
  48. while (uniqueResults.size() < Math.min(numRequested, condFreq.get(keyPair).size())) {
  49. if (iterationsNum++ > MAX_ITERATIONS) break;
  50. if (suggGenerator.keySet().contains(keyPair)) {
  51.  
  52. String suggestion = suggGenerator.get(keyPair).next();
  53. if (!uniqueResults.contains(suggestion)) result.add(suggestion);
  54.  
  55. uniqueResults.add(suggestion);
  56. }
  57. }
  58. return result;
  59. }
  60.  
  61. public void train() {
  62. // train n-grams condFreq
  63.  
  64. // System.out.println("В файле " + QUESTIONS_FILE_NAME + " " + this.text.split(" ").length + " токенов.");
  65. String tokens[] = this.text.split(" ");
  66. String first = tokens[0];
  67. String second = tokens[1];
  68.  
  69. for (Integer i = 2; i < tokens.length; i++) {
  70. Pair<String, String> fs = new Pair<>(first, second);
  71. if (condFreq.containsKey(fs)) {
  72. if (condFreq.get(fs).containsKey(tokens[i])) {
  73. Double curVal = condFreq.get(fs).get(tokens[i]);
  74. condFreq.get(fs).put(tokens[i], curVal + 1.);
  75. } else {
  76. condFreq.get(fs).putIfAbsent(tokens[i], 1.);
  77. }
  78. } else {
  79.  
  80. condFreq.putIfAbsent(fs, new HashMap<>());
  81. condFreq.get(fs).put(tokens[i], 1.);
  82. }
  83. first = second;
  84. second = tokens[i];
  85. }
  86.  
  87. for (Pair mainKey : condFreq.keySet()) {
  88. for (String key : condFreq.get(mainKey).keySet()) {
  89. if (!suggGenerator.keySet().contains(mainKey)) {
  90. suggGenerator.put(mainKey, new RandomCollection());
  91. }
  92. suggGenerator.get(mainKey).add(condFreq.get(mainKey).get(key), key);
  93. }
  94. }
  95. }
  96.  
  97.  
  98. public PredictorAgent() {
  99. try {
  100. FileInputStream fis = new FileInputStream(
  101. new ClassPathResource(QUESTIONS_FILE_NAME).getFile().getPath());
  102. BufferedReader br = new BufferedReader(new InputStreamReader(fis));
  103. String question = "";
  104. while ((question = br.readLine()) != null) {
  105. this.text += (question.toLowerCase() + PAD_TOKEN);
  106. }
  107. } catch (IOException e) {
  108. System.out.println(e.getMessage());
  109. }
  110. }
  111.  
  112. public String getText() {
  113. return text;
  114. }
  115.  
  116. private class RandomCollection<E> {
  117. private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  118. private final Random random;
  119. private double total = 0;
  120.  
  121. public RandomCollection() {
  122. this(new Random());
  123. }
  124.  
  125. public RandomCollection(Random random) {
  126. this.random = random;
  127. }
  128.  
  129. public RandomCollection<E> add(double weight, E result) {
  130. if (weight <= 0) return this;
  131. total += weight;
  132. map.put(total, result);
  133. return this;
  134. }
  135.  
  136. public E next() {
  137. double value = random.nextDouble() * total;
  138. return map.higherEntry(value).getValue();
  139. }
  140. }
  141. }
  142.  
  143. // Legacy (unused): normalized each pair's follower counts into conditional probabilities.
  144. // for (Object mainKey : condFreq.keySet()) {
  145. // int sum = 0;
  146. // HashMap<String, Double> tempCondFreq = condFreq.get(mainKey);
  147. // for (String key : tempCondFreq.keySet()) {
  148. // sum += tempCondFreq.get(key);
  149. // }
  150. //
  151. // // transforming frequencies to probabilities
  152. // for (String key : tempCondFreq.keySet()) {
  153. // tempCondFreq.put(key, tempCondFreq.get(key) / sum);
  154. // }
  155. // }
  156. // System.out.println("Test query (prob): " + condFreq.get(new Pair<String, String>("расходы", "на")));
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement