Advertisement
kosievdmerwe

1220

Aug 6th, 2022
67
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.84 KB | None | 0 0
  # Counts are reported modulo this value (see Solution.countVowelPermutation).
  MOD = 10**9 + 7

  # The alphabet; a vowel's position in this string is its matrix index.
  VOWELS = "aeiou"

  # For each vowel, the vowels allowed to come immediately after it.
  ALLOWED_NEXT = {
      "a": "e",
      "e": "ai",
      "i": "aeou",
      "o": "iu",
      "u": "a",
  }

  # ALLOWED_NEXT translated to positions within VOWELS:
  # {index of vowel: [indexes of vowels that may follow it]}.
  # str.index (not str.find) so a character missing from VOWELS raises
  # ValueError instead of silently becoming -1, i.e. aliasing the last vowel.
  ALLOWED_NEXT_IDXES = {
      VOWELS.index(k): [VOWELS.index(v) for v in vs]
      for k, vs in ALLOWED_NEXT.items()
  }
  16.  
  def gen_update_matrix() -> List[List[int]]:
      """Build the square 0/1 transition matrix for the vowel-successor rules.

      ``matrix[dst][src]`` is 1 exactly when ``ALLOWED_NEXT_IDXES`` lists
      ``dst`` as a successor of ``src``, i.e. ``VOWELS[dst]`` may directly
      follow ``VOWELS[src]``. Every other entry is 0.
      """
      size = len(VOWELS)
      matrix = [[0 for _ in range(size)] for _ in range(size)]
      # Flatten the adjacency lists into (row, column) pairs, in table order.
      edges = (
          (dst, src)
          for src, successors in ALLOWED_NEXT_IDXES.items()
          for dst in successors
      )
      for dst, src in edges:
          matrix[dst][src] = 1
      return matrix
  25.    
  26.  
  def matrix_mult(
      a: List[List[int]],
      b: List[List[int]],
      mod: "int | None" = None,
  ) -> List[List[int]]:
      """Return the matrix product ``a @ b`` with every entry reduced modulo ``mod``.

      Args:
          a: An r x k matrix given as a list of equal-length rows.
          b: A k x c matrix given as a list of equal-length rows.
          mod: Modulus applied to each entry of the result. ``None`` (the
              default) means the module-level ``MOD``.

      Returns:
          The r x c product as a new list of lists. An empty ``a`` yields ``[]``.

      Raises:
          ValueError: if the column count of ``a`` differs from the row count
              of ``b``.
      """
      if mod is None:
          mod = MOD
      if not a:
          return []

      a_rows = len(a)
      a_cols = len(a[0])
      b_rows = len(b)
      b_cols = len(b[0]) if b else 0
      if a_cols != b_rows:
          raise ValueError(f"{a_rows}x{a_cols} {b_rows}x{b_cols}")

      # Transpose b once so each output entry is a dot product of two sequences.
      b_columns = list(zip(*b))
      # Python ints never overflow, so reducing once per entry gives the same
      # result as reducing after every partial product, with far fewer % ops.
      return [
          [sum(x * y for x, y in zip(row, col)) % mod for col in b_columns]
          for row in a
      ]
  41.  
  42.  
  def matrix_pow_slow(m: List[List[int]], p: int) -> List[List[int]]:
      """Return ``m`` raised to the ``p``-th power by ``p`` successive multiplications.

      Performs O(p) matrix products; presumably kept as a straightforward
      reference to check ``matrix_pow`` against. Entries are reduced modulo
      ``MOD`` by ``matrix_mult``. ``p == 0`` yields the identity matrix.

      Raises:
          ValueError: if ``m`` is not square or ``p`` is negative.
      """
      size = len(m)
      if any(len(row) != size for row in m):
          raise ValueError(f"matrix must be square, got {size} rows")
      if p < 0:
          raise ValueError(f"exponent must be non-negative, got {p}")

      # Start from the identity so p == 0 is correct and the result never
      # aliases the caller's matrix (the old code returned `m` itself for p <= 1).
      ans = [[int(r == c) for c in range(size)] for r in range(size)]
      for _ in range(p):
          ans = matrix_mult(ans, m)
      return ans
  48.  
  49.  
  def matrix_pow(m: List[List[int]], p: int) -> List[List[int]]:
      """Return ``m`` raised to the ``p``-th power using binary exponentiation.

      Needs O(log p) calls to ``matrix_mult``, which reduces entries modulo
      ``MOD``. ``p == 0`` yields the identity matrix.

      Raises:
          ValueError: if ``m`` is not square or ``p`` is negative (previously a
              negative exponent silently produced the identity).
      """
      size = len(m)
      if any(len(row) != size for row in m):
          raise ValueError(f"matrix must be square, got {size} rows")
      if p < 0:
          raise ValueError(f"exponent must be non-negative, got {p}")

      ans = [[int(r == c) for c in range(size)] for r in range(size)]
      base = m  # m ** (2 ** k), where k is the bit of p currently examined
      remaining = p
      while remaining:
          if remaining & 1:
              ans = matrix_mult(ans, base)
          remaining >>= 1
          # Skip the final squaring: its result would never be used.
          if remaining:
              base = matrix_mult(base, base)
      return ans
  65.  
  66.  
  class Solution:
      def countVowelPermutation(self, n: int) -> int:
          """Count length-``n`` strings over VOWELS obeying ALLOWED_NEXT, modulo MOD.

          Raising the transition matrix to the (n - 1)-th power counts walks of
          n - 1 steps. Multiplying a row of ones by it gives, per index k, the
          number of valid strings that start with ``VOWELS[k]``; their sum is
          the answer.

          NOTE(review): n < 1 is not rejected here; presumably callers only
          pass n >= 1 (problem constraint) — confirm.
          """
          alphabet_size = len(VOWELS)
          if n == 1:
              # Every single vowel is a valid string on its own.
              return alphabet_size

          walks = matrix_pow(gen_update_matrix(), n - 1)
          (per_start,) = matrix_mult([[1] * alphabet_size], walks)
          return sum(per_start) % MOD
  75.        
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement