cyga

Bold words in a string: polynomial hashing, KMP, and trie.

Sep 19th, 2021
992
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.42 KB | None | 0 0
  1. class PolyHash():
  2.     def __init__(self, seq: List[int], prime=37, rem=10**9+7):
  3.         n = len(seq)
  4.         self.rem = rem
  5.         self.hashes = [0]*(n+1)
  6.         self.pows = [1]*(n+1)
  7.         for i in range(1, n+1):
  8.             self.pows[i] = self.pows[i-1]*prime % rem
  9.             self.hashes[i] = (self.hashes[i-1]*prime + seq[i-1]) % rem
  10.  
  11.     def hash(self, l: int, r: int):
  12.         ''' [l, r) '''
  13.         return (self.hashes[r] - (self.hashes[l]*self.pows[r-l]) % self.rem) % self.rem
  14.    
  15. def string_hasher(s: string) -> PolyHash:
  16.     return PolyHash([ord(ch)-ord('a') for ch in s])
  17.  
  18. def calc_word_hashes(words: List[str]) -> List[int]:
  19.     word_hashes = [None]*len(words)
  20.     for i, word in enumerate(words):
  21.         word_hasher = string_hasher(word)
  22.         word_hashes[i] = word_hasher.hash(0, len(word))
  23.     return word_hashes
  24.    
  25. def scanline(bolds: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
  26.     START, END = 0, 1
  27.     pois = [(start, START) for start, _ in bolds] + [(end, END) for _, end in bolds]
  28.     pois.sort()
  29.    
  30.     n_opened = 0
  31.     b_start = None
  32.     res = []
  33.     for x, tip in pois:
  34.         if tip == START:
  35.             if n_opened == 0:
  36.                 b_start = x
  37.             n_opened += 1
  38.         else:
  39.             n_opened -= 1
  40.             if n_opened == 0:
  41.                 res.append((b_start, x))
  42.                
  43.     return res
  44.  
  45. def find_all_bolds(words: List[str], s: str) -> List[Tuple[int, int]]:
  46.     word_hashes = calc_word_hashes(words)
  47.     s_hasher = string_hasher(s)
  48.     bolds = []
  49.     for i in range(1, len(s)+1):
  50.         for j, h in enumerate(word_hashes):
  51.             if i >= len(words[j]):
  52.                 if s_hasher.hash(i-len(words[j]), i) == h:
  53.                     bolds.append((i-len(words[j]), i))
  54.     return bolds
  55.    
  56. def prefix_sum(s: str):
  57.     n = len(s)
  58.     pi = [0]*n
  59.     for i in range(1, n):
  60.         j = pi[i-1]
  61.         while j > 0 and s[i] != s[j]:
  62.             j = pi[j-1]
  63.         if s[i] == s[j]:
  64.             j += 1
  65.         pi[i] = j
  66.     return pi
  67.    
  68. def find_all_bolds_kmp(words: List[str], s: str) -> List[Tuple[int, int]]:
  69.     pis = [prefix_sum(word) for word in words]
  70.     js = [0]*len(words)
  71.     bolds = []
  72.     for i, ch in enumerate(s):
  73.         for k, j in enumerate(js):
  74.             while j >= len(words[k]) or (j > 0 and ch != words[k][j]):
  75.                 j = pis[k][j-1]
  76.             if ch == words[k][j]:
  77.                 j += 1
  78.             js[k] = j
  79.            
  80.             if j == len(words[k]):
  81.                 bolds.append((i+1-len(words[k]), i+1))
  82.     return bolds
  83.  
  84. class TrieNode:
  85.     def __init__(self):
  86.         self.children = {}
  87.         self.ends = set()
  88.        
  89. class Trie:
  90.     def __init__(self):
  91.         self.root = TrieNode()
  92.  
  93.     def insert(self, word: str) -> None:
  94.         """ Inserts a word into the trie. """
  95.         p = self.root
  96.         for ch in word:
  97.             if ch not in p.children:
  98.                 p.children[ch] = TrieNode()
  99.             p = p.children[ch]
  100.  
  101.         p.ends.add(len(word))
  102.  
  103. def find_all_bolds_trie(words, s) -> List[Tuple[int, int]]:
  104.     ''' O(sum of words lengths + len(s)*?) '''
  105.     trie = Trie()
  106.     for word in words:
  107.         trie.insert(word)
  108.        
  109.     trie_nodes = []
  110.     bolds = []
  111.     for i, ch in enumerate(s):
  112.         trie_nodes2 = []
  113.         trie_nodes.append(trie.root)
  114.         for trie_node in trie_nodes:
  115.             if ch not in trie_node.children:
  116.                 continue
  117.             trie_node = trie_node.children[ch]
  118.             for length in trie_node.ends:
  119.                 bolds.append((i+1-length, i+1))
  120.             trie_nodes2.append(trie_node)
  121.         trie_nodes = trie_nodes2
  122.            
  123.     return bolds
  124.    
  125. class Solution:
  126.     def boldWords(self, words: List[str], s: str) -> str:
  127.         #bolds = find_all_bolds(words, s)
  128.         #bolds = find_all_bolds_kmp(words, s)
  129.         bolds = find_all_bolds_trie(words, s)
  130.                        
  131.         bolds = scanline(bolds)
  132.         res = []
  133.         j, k = 0, 0
  134.         for i in range(len(s)+1):
  135.             if j < len(bolds) and bolds[j][k] == i:
  136.                 if k == 0:
  137.                     res.append("<b>")
  138.                     k += 1
  139.                 else:
  140.                     res.append("</b>")
  141.                     k = 0
  142.                     j += 1
  143.                
  144.             if i < len(s):
  145.                 res.append(s[i])
  146.                        
  147.         return ''.join(res)
Advertisement
Add Comment
Please, Sign In to add comment