Advertisement
papple23g

230223 卡氏乘積效能問題 (Cartesian product performance issue)

Feb 22nd, 2023
884
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.50 KB | None | 0 0
  1.  
  2. import time
  3. from dataclasses import dataclass
  4. from typing import List
  5.  
  6. import numpy as np
  7.  
# Seed NumPy's global RNG so the random benchmark data generated under
# ``__main__`` (via np.random.randint) is the same on every run.
np.random.seed(0)
  9.  
  10.  
  11. @dataclass
  12. class WrappedArray:
  13.     arr: np.ndarray
  14.  
  15.     def __matmul__(self, other) -> int | float:
  16.         return (self.arr @ other.arr).item()
  17.  
  18.  
  19. def get_naive_method_mat(fore_states: List[WrappedArray], post_states: List[WrappedArray]) -> np.ndarray:
  20.     return np.array(
  21.         [
  22.             [
  23.                 fore_state @ post_state
  24.                 for post_state in post_states
  25.             ]
  26.             for fore_state in fore_states
  27.         ]
  28.     )
  29.  
  30.  
  31. def get_np_fromiter_method_mat(fore_states: List[WrappedArray], post_states: List[WrappedArray]) -> np.ndarray:
  32.     cartesian_prod = np.array(
  33.         np.meshgrid(fore_states, post_states)
  34.     ).T.reshape(-1, 2)
  35.     iterable = (
  36.         fore_state @ post_state for fore_state, post_state in cartesian_prod[:, ]
  37.     )
  38.     return np.fromiter(iterable, dtype=float).reshape(N, N)
  39.  
  40.  
  41. def get_arraize_method_mat(fore_states: List[WrappedArray], post_states: List[WrappedArray]) -> np.ndarray:
  42.  
  43.     def to_array(states: List[WrappedArray]) -> np.ndarray:
  44.         return np.array([
  45.             wrappedArray.arr.flatten()
  46.             for wrappedArray in states
  47.         ])
  48.  
  49.     fore_array = to_array(fore_states)
  50.     post_array = to_array(post_states)
  51.     return fore_array @ post_array.T
  52.  
  53.  
  54. if __name__ == '__main__':
  55.  
  56.     # 準備資料
  57.     N = 1000
  58.     fore_states = [
  59.         WrappedArray(np.random.randint(0, 100, (1, 2)))
  60.         for _ in range(N)
  61.     ]
  62.     post_states = [
  63.         WrappedArray(np.random.randint(0, 100, (2, 1)))
  64.         for _ in range(N)
  65.     ]
  66.  
  67.     # 開始執行
  68.     t0 = time.time()
  69.     naive_method_mat = get_naive_method_mat(
  70.         fore_states, post_states
  71.     )
  72.     t1 = time.time()
  73.     print(f'naive_method_mat takes {t1 - t0:.2f} seconds')
  74.  
  75.     np_fromiter_method_mat = get_np_fromiter_method_mat(
  76.         fore_states, post_states
  77.     )
  78.     t2 = time.time()
  79.     print(f'np_fromiter_method_mat takes {t2 - t1:.2f} seconds')
  80.  
  81.     arraize_method_mat = get_arraize_method_mat(
  82.         fore_states, post_states
  83.     )
  84.     t3 = time.time()
  85.     print(f'arraize_method_mat takes {t3 - t2:.2f} seconds')
  86.  
  87.     assert np.array_equal(naive_method_mat, np_fromiter_method_mat)
  88.     assert np.array_equal(naive_method_mat, arraize_method_mat)
  89.  
  90.     """ output
  91.    naive_method_mat takes 1.24 seconds
  92.    np_fromiter_method_mat takes 2.21 seconds
  93.    arraize_method_mat takes 0.02 seconds
  94.    """
  95.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement