Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
import javafx.util.Pair;
import org.nd4j.linalg.io.ClassPathResource;
- public class PredictorAgent {
- public static void main(String[] args) {
- PredictorAgent pa = new PredictorAgent();
- pa.train();
- long startTime = System.currentTimeMillis();
- System.out.println(pa.getSuggestions("процент исполнения", 8));
- System.out.println("Прошло "
- + (new Double(System.currentTimeMillis() - startTime) / 1000) + " секунд.");
- }
- private String text = "";
- private static final String QUESTIONS_FILE_NAME = "questions.txt";
- private static final String PAD_TOKEN = " # ";
- private static final int MAX_ITERATIONS = 100;
- private HashMap<Pair<String, String>, HashMap<String, Double>> condFreq = new HashMap<>();
- private HashMap<Pair<String, String>, RandomCollection<String>> suggGenerator = new HashMap<>();
- public ArrayList<String> getSuggestions(String request, int numRequested) {
- request = request.toLowerCase();
- if (condFreq.isEmpty()) {
- this.train();
- }
- int numRequestTokens = request.split(" ").length;
- if (numRequestTokens < 2) return null;
- String first = request.split(" ")[numRequestTokens - 2];
- String second = request.split(" ")[numRequestTokens - 1];
- Pair<String, String> keyPair = new Pair<>(first, second);
- HashMap<String, Double> reqFreq = condFreq.get(keyPair);
- HashSet<String> uniqueResults = new HashSet<>();
- ArrayList<String> result = new ArrayList<>();
- if (!condFreq.containsKey(keyPair)) return result;
- int iterationsNum = 0;
- while (uniqueResults.size() < Math.min(numRequested, condFreq.get(keyPair).size())) {
- if (iterationsNum++ > MAX_ITERATIONS) break;
- if (suggGenerator.keySet().contains(keyPair)) {
- String suggestion = suggGenerator.get(keyPair).next();
- if (!uniqueResults.contains(suggestion)) result.add(suggestion);
- uniqueResults.add(suggestion);
- }
- }
- return result;
- }
- public void train() {
- // train n-grams condFreq
- // System.out.println("В файле " + QUESTIONS_FILE_NAME + " " + this.text.split(" ").length + " токенов.");
- String tokens[] = this.text.split(" ");
- String first = tokens[0];
- String second = tokens[1];
- for (Integer i = 2; i < tokens.length; i++) {
- Pair<String, String> fs = new Pair<>(first, second);
- if (condFreq.containsKey(fs)) {
- if (condFreq.get(fs).containsKey(tokens[i])) {
- Double curVal = condFreq.get(fs).get(tokens[i]);
- condFreq.get(fs).put(tokens[i], curVal + 1.);
- } else {
- condFreq.get(fs).putIfAbsent(tokens[i], 1.);
- }
- } else {
- condFreq.putIfAbsent(fs, new HashMap<>());
- condFreq.get(fs).put(tokens[i], 1.);
- }
- first = second;
- second = tokens[i];
- }
- for (Pair mainKey : condFreq.keySet()) {
- for (String key : condFreq.get(mainKey).keySet()) {
- if (!suggGenerator.keySet().contains(mainKey)) {
- suggGenerator.put(mainKey, new RandomCollection());
- }
- suggGenerator.get(mainKey).add(condFreq.get(mainKey).get(key), key);
- }
- }
- }
- public PredictorAgent() {
- try {
- FileInputStream fis = new FileInputStream(
- new ClassPathResource(QUESTIONS_FILE_NAME).getFile().getPath());
- BufferedReader br = new BufferedReader(new InputStreamReader(fis));
- String question = "";
- while ((question = br.readLine()) != null) {
- this.text += (question.toLowerCase() + PAD_TOKEN);
- }
- } catch (IOException e) {
- System.out.println(e.getMessage());
- }
- }
- public String getText() {
- return text;
- }
- private class RandomCollection<E> {
- private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
- private final Random random;
- private double total = 0;
- public RandomCollection() {
- this(new Random());
- }
- public RandomCollection(Random random) {
- this.random = random;
- }
- public RandomCollection<E> add(double weight, E result) {
- if (weight <= 0) return this;
- total += weight;
- map.put(total, result);
- return this;
- }
- public E next() {
- double value = random.nextDouble() * total;
- return map.higherEntry(value).getValue();
- }
- }
- }
// Old but Gold
// for (Object mainKey : condFreq.keySet()) {
//     int sum = 0;
//     HashMap<String, Double> tempCondFreq = condFreq.get(mainKey);
//     for (String key : tempCondFreq.keySet()) {
//         sum += tempCondFreq.get(key);
//     }
//
//     // transforming frequencies to probabilities
//     for (String key : tempCondFreq.keySet()) {
//         tempCondFreq.put(key, tempCondFreq.get(key) / sum);
//     }
// }
// System.out.println("Test query (prob): " + condFreq.get(new Pair<String, String>("расходы", "на")));
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement