Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#pragma once

#include <array>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <vector>
/// Aho-Corasick style automaton over the lowercase Latin alphabet 'a'..'z'.
///
/// The patterns are stored in a trie of heap-allocated Node objects. Each node
/// additionally carries a suffix link and the number of terminal nodes found on
/// its suffix-link chain, which lets numEntries() sum matches in one pass.
///
/// The class is move-only. The moved-from object no longer owns a trie and
/// must not be queried.
class Aho {
private:
    /// Size of the supported alphabet: one slot per letter in 'a'..'z'.
    static constexpr int kMaxSymbols = 'z' - 'a' + 1;

public:
    /// A single trie node. Children are owned through `transitions`; every
    /// other pointer is a non-owning reference into the same trie.
    struct Node {
        Node(Node* parent, char lastSymbol)
            : parent(parent)
            , lastSymbol(lastSymbol) {}

        /// Owned children, indexed by ord(symbol); empty slot == no edge.
        std::array<std::unique_ptr<Node>, kMaxSymbols> transitions;
        Node* const parent;  // nullptr for root.
        /// Symbol on the edge from `parent` to this node ('\0' for the root).
        const char lastSymbol;
        /// Lazily computed suffix link; nullptr means "not computed yet".
        Node* suffixLink = nullptr;
        /// True when some pattern ends exactly at this node.
        bool isTerminal = false;
        /// Number of terminal nodes on the suffix-link chain, this node included.
        int reachableTerminals = 0;
    };

public:
    /// Builds the automaton from `patterns`.
    /// @throws std::logic_error   if a pattern is empty.
    /// @throws std::out_of_range  if a pattern has a symbol outside 'a'..'z'.
    Aho(const std::vector<std::string>& patterns);

    Aho() = delete;
    Aho(const Aho&) = delete;
    Aho(Aho&&) = default;
    Aho& operator=(const Aho&) = delete;
    Aho& operator=(Aho&&) = default;

    /// Sums `reachableTerminals` over the states visited while feeding the
    /// text through the automaton. The text is expected to consist of
    /// symbols from 'a'..'z'.
    int numEntries(const std::string&) const noexcept;

private:
    /// Inserts one pattern into the trie and marks its last node terminal.
    void addPattern(const std::string&);
    /// Computes suffix links for the given node and everything below it.
    void calcSubtreeSuffixLinks(Node&) noexcept;
    /// Returns (computing and caching on first use) the node's suffix link.
    Node* calcSuffixLink(Node&) noexcept;
    /// Fills `reachableTerminals` for every node of the trie.
    void calcReachableTerminals();
    /// Performs one automaton step from the given node by the given symbol.
    Node* advance(Node*, char) const noexcept;

    /// Maps a symbol to its index in `Node::transitions`.
    /// @throws std::out_of_range if `ch` is not in 'a'..'z'.
    static int ord(char ch) {
        using namespace std::string_literals;
        const bool isLegal = ('a' <= ch && ch <= 'z');
        if (!isLegal) {
            throw std::out_of_range("Symbol '"s + ch + "' is not a legal symbol");
        }
        return ch - 'a';
    }

private:
    // Deliberately NOT const: a const std::unique_ptr member cannot be moved
    // from, which would turn the defaulted move operations above into deleted
    // ones and make the class neither movable nor copyable.
    std::unique_ptr<Node> root;
};
- Aho::Aho(const std::vector<std::string>& patterns)
- : root(std::make_unique<Node>(nullptr, '\0'))
- {
- for (const auto& pattern : patterns) {
- if (pattern == "") {
- throw std::logic_error("Empty strings are not allowed");
- }
- addPattern(pattern);
- }
- root->suffixLink = root.get();
- calcSubtreeSuffixLinks(*root);
- calcReachableTerminals();
- }
- void Aho::addPattern(const std::string& pattern) {
- Node* ptr = root.get();
- for (const auto& ch : pattern) {
- const auto index = ord(ch);
- if (!ptr->transitions[index]) {
- ptr->transitions[index] = std::make_unique<Node>(ptr, ch);
- }
- ptr = ptr->transitions[index].get();
- }
- ptr->isTerminal = true;
- }
- void Aho::calcSubtreeSuffixLinks(Node& node) noexcept {
- calcSuffixLink(node);
- for (const auto& child : node.transitions) {
- if (child) {
- calcSubtreeSuffixLinks(*child);
- }
- }
- }
- auto Aho::calcSuffixLink(Node& node) noexcept -> Node* {
- if (node.suffixLink) {
- return node.suffixLink;
- }
- Node* ptr = node.parent;
- const auto index = ord(node.lastSymbol);
- while (ptr != root.get()) {
- ptr = calcSuffixLink(*ptr);
- if (ptr->transitions[index]) {
- return node.suffixLink = ptr->transitions[index].get();
- }
- }
- return node.suffixLink = root.get();
- }
- void Aho::calcReachableTerminals() {
- std::queue<Node*> queue;
- queue.push(root.get());
- root->reachableTerminals = 0;
- while (!queue.empty()) {
- const auto p = queue.front();
- queue.pop();
- p->reachableTerminals = p->suffixLink->reachableTerminals;
- if (p->isTerminal) {
- ++p->reachableTerminals;
- }
- for (const auto& p : p->transitions) {
- if (p) {
- queue.push(p.get());
- }
- }
- }
- }
- auto Aho::advance(Node* p, char ch) const noexcept -> Node* {
- const auto index = ord(ch);
- if (p->transitions[index]) {
- return p->transitions[index].get();
- }
- if (p == root.get()) {
- return root.get();
- }
- return advance(p->suffixLink, ch);
- }
- int Aho::numEntries(const std::string& s) const noexcept {
- int result = 0;
- Node* ptr = root.get();
- for (char ch : s) {
- ptr = advance(ptr, ch);
- result += ptr->reachableTerminals;
- }
- return result;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement