Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package ru.liboskat.task6;
- import com.fasterxml.jackson.annotation.JsonProperty;
- import com.fasterxml.jackson.dataformat.xml.XmlMapper;
- import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlElementWrapper;
- import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlProperty;
- import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlRootElement;
- import org.apache.commons.math3.linear.*;
- import ru.stachek66.nlp.mystem.holding.Factory;
- import ru.stachek66.nlp.mystem.holding.MyStem;
- import ru.stachek66.nlp.mystem.holding.MyStemApplicationException;
- import ru.stachek66.nlp.mystem.holding.Request;
- import ru.stachek66.nlp.mystem.model.Info;
- import scala.Option;
- import scala.collection.JavaConversions;
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.IOException;
- import java.io.InputStreamReader;
- import java.nio.file.Files;
- import java.nio.file.Paths;
- import java.util.*;
- import java.util.function.Function;
- import java.util.stream.Collectors;
- public class VectorSearchUtil {
- public static final int K = 20;
- private static final XmlMapper mapper = new XmlMapper();
- private final static MyStem mystemAnalyzer = new Factory("-igd --eng-gr --format json --weight")
- .newMyStem("3.0", Option.empty()).get();
- public static void main(String[] args) throws IOException {
- List<Term> terms = mapper.readValue(new File("src/ru/liboskat/task6/inverted_index.xml"), Terms.class)
- .getTerm();
- List<Article> documents = mapper.readValue(new File("example.xml"), Documents.class).getDocument();
- int termCount = terms.size();
- int documentsCount = documents.size();
- double[][] matrix = new double[termCount][documentsCount];
- for (int i = 0; i < termCount; i++) {
- int termId = i;
- terms.get(termId).getDoc().forEach(doc -> matrix[termId][doc.getId() - 1] = doc.getTfIdf());
- }
- RealMatrix realMatrix = MatrixUtils.createRealMatrix(matrix);
- SingularValueDecomposition decomposition = new SingularValueDecomposition(realMatrix);
- RealMatrix uk = decomposition.getU().getSubMatrix(0, matrix.length - 1, 0, K - 1);
- RealMatrix sk = decomposition.getS().getSubMatrix(0, K - 1, 0, K - 1);
- RealMatrix skInverse = MatrixUtils.inverse(sk);
- List<SearchResult> searchResults = readQueries().stream()
- .map(query -> getSearchResult(query, terms, documents, realMatrix, uk, skInverse))
- .collect(Collectors.toList());
- Files.write(Paths.get("src/ru/liboskat/task6/vector_search_result.txt"),
- searchResults.stream().map(SearchResult::toString).collect(Collectors.toList()));
- }
- private static SearchResult getSearchResult(String query, List<Term> terms, List<Article> documents,
- RealMatrix realMatrix, RealMatrix uk, RealMatrix skInverse) {
- RealMatrix q = getSimpleQueryVector(query, terms).multiply(uk).multiply(skInverse);
- List<ArticleSearchResult> articleSearchResults = new ArrayList<>();
- for (int j = 0; j < documents.size(); j++) {
- RealMatrix d = realMatrix.getColumnMatrix(j).transpose().multiply(uk).multiply(skInverse);
- Article doc = documents.get(j);
- articleSearchResults.add(
- new ArticleSearchResult(doc.getId(), doc.getTitle(), doc.getUrl(), getSimilarity(q, d)));
- }
- return new SearchResult(query, articleSearchResults.stream()
- .sorted(Comparator.comparing(ArticleSearchResult::getScore).reversed())
- .limit(10)
- .collect(Collectors.toList()));
- }
- private static double getSimilarity(RealMatrix q, RealMatrix d) {
- double vectorsProduct = 0.0;
- double documentVectorSumOfSquares = 0.0;
- double queryVectorSumOfSquares = 0;
- for (int i = 0; i < K; i++) {
- double documentVectorElement = d.getEntry(0, i);
- double queryVectorElement = q.getEntry(0, i);
- vectorsProduct += documentVectorElement * queryVectorElement;
- documentVectorSumOfSquares += documentVectorElement * documentVectorElement;
- queryVectorSumOfSquares += queryVectorElement * queryVectorElement;
- }
- if (documentVectorSumOfSquares == 0 || queryVectorSumOfSquares == 0) {
- return 0;
- } else {
- return vectorsProduct / (Math.sqrt(documentVectorSumOfSquares) * Math.sqrt(queryVectorSumOfSquares));
- }
- }
- private static RealMatrix getSimpleQueryVector(String query, List<Term> terms) {
- List<String> queryParts = Arrays.stream(query.split(" "))
- .peek(VectorSearchUtil::stem)
- .collect(Collectors.toList());
- double[] queryVector = new double[terms.size()];
- for (int i = 0; i < terms.size(); i++) {
- queryVector[i] = queryParts.contains(terms.get(i).getValue()) ? 1 : 0;
- }
- return MatrixUtils.createRowRealMatrix(queryVector);
- }
- private static String stem(String word) {
- try {
- List<Info> result =
- JavaConversions.seqAsJavaList(
- mystemAnalyzer
- .analyze(Request.apply(word))
- .info()
- .toSeq());
- if (!result.isEmpty() && result.get(0).lex().nonEmpty()) {
- return result.get(0).lex().get();
- }
- return word;
- } catch (MyStemApplicationException e) {
- throw new IllegalArgumentException(e);
- }
- }
- private static List<String> readQueries() throws IOException {
- BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
- System.out.println("Введите число запросов");
- int numberOfQueries = Integer.parseInt(br.readLine());
- List<String> queries = new ArrayList<>();
- System.out.println("Введите " + numberOfQueries + " запросов");
- for (int i = 0; i < numberOfQueries; i++) {
- queries.add(br.readLine());
- }
- br.close();
- return queries;
- }
- private static class SearchResult {
- private String query;
- private List<ArticleSearchResult> found;
- public SearchResult(String query, List<ArticleSearchResult> found) {
- this.query = query;
- this.found = found;
- }
- @Override
- public String toString() {
- StringBuilder stringBuilder = new StringBuilder();
- stringBuilder.append("Запрос='");
- stringBuilder.append(query);
- stringBuilder.append("',\nНайдено=");
- if (found.size() == 0) {
- stringBuilder.append("ничего");
- } else {
- stringBuilder.append("{\n");
- found.forEach(articleSearchResult -> {
- stringBuilder.append(" ");
- stringBuilder.append(articleSearchResult);
- stringBuilder.append(",\n");
- });
- stringBuilder.append("},");
- }
- stringBuilder.append("\n");
- return stringBuilder.toString();
- }
- }
- private static class ArticleSearchResult {
- private Integer id;
- private String title;
- private String url;
- private Double score;
- public ArticleSearchResult(Integer id, String title, String url, Double score) {
- this.id = id;
- this.title = title;
- this.url = url;
- this.score = score;
- }
- public Integer getId() {
- return id;
- }
- public Double getScore() {
- return score;
- }
- @Override
- public String toString() {
- return "Статья{" +
- "url='" + url + '\'' +
- ", score='" + score + '\'' +
- '}';
- }
- }
- @JacksonXmlRootElement(localName = "terms")
- private static class Terms {
- @JacksonXmlElementWrapper(useWrapping = false)
- private List<Term> term;
- public Terms() {
- }
- public Terms(List<Term> term) {
- this.term = term;
- }
- public List<Term> getTerm() {
- return term;
- }
- public void setTerm(List<Term> term) {
- this.term = term;
- }
- }
- private static class Term {
- @JacksonXmlProperty(isAttribute = true)
- private String value;
- @JacksonXmlElementWrapper(useWrapping = false)
- private List<Document> doc;
- public Term() {
- }
- public Term(String value, List<Document> doc) {
- this.value = value;
- this.doc = doc;
- }
- public Optional<Document> findDocument(int documentId) {
- return doc.stream().filter(document -> document.getId().equals(documentId)).findAny();
- }
- public String getValue() {
- return value;
- }
- public void setValue(String value) {
- this.value = value;
- }
- public List<Document> getDoc() {
- return doc;
- }
- public void setDoc(List<Document> doc) {
- this.doc = doc;
- }
- }
- private static class Document implements Comparable<Document> {
- @JacksonXmlProperty(isAttribute = true)
- private Integer id;
- @JacksonXmlProperty(isAttribute = true)
- private Long count;
- @JacksonXmlProperty(isAttribute = true)
- private Double tfIdf;
- public Document() {
- }
- public Document(Integer id, Long count) {
- this.id = id;
- this.count = count;
- }
- public Integer getId() {
- return id;
- }
- public void setId(Integer id) {
- this.id = id;
- }
- public Long getCount() {
- return count;
- }
- public void setCount(Long count) {
- this.count = count;
- }
- @JsonProperty("tf-idf")
- public Double getTfIdf() {
- return tfIdf;
- }
- public void setTfIdf(Double tfIdf) {
- this.tfIdf = tfIdf;
- }
- @Override
- public int compareTo(Document o) {
- return Comparator
- .comparing(Document::getCount)
- .reversed()
- .thenComparing(Document::getId)
- .compare(this, o);
- }
- }
- @JacksonXmlRootElement(localName = "documents")
- private static class Documents {
- @JacksonXmlElementWrapper(useWrapping = false)
- private List<Article> document;
- public Documents() {
- }
- public Documents(List<Article> document) {
- this.document = document;
- }
- public Article findArticle(int id) {
- return getDocument().stream()
- .filter(article -> id == article.getId())
- .findAny().orElseThrow(IllegalArgumentException::new);
- }
- public List<Article> getDocument() {
- return document;
- }
- public void setDocument(List<Article> document) {
- this.document = document;
- }
- }
- private static class Article {
- @JacksonXmlProperty(isAttribute = true)
- private Integer id;
- private String url;
- private String title;
- private String text;
- private String keywords;
- public Article() {
- }
- public Integer getId() {
- return id;
- }
- public void setId(Integer id) {
- this.id = id;
- }
- public String getUrl() {
- return url;
- }
- public void setUrl(String url) {
- this.url = url;
- }
- public String getTitle() {
- return title;
- }
- public void setTitle(String title) {
- this.title = title;
- }
- public String getText() {
- return text;
- }
- public void setText(String text) {
- this.text = text;
- }
- public String getKeywords() {
- return keywords;
- }
- public void setKeywords(String keywords) {
- this.keywords = keywords;
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement