Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import time
- from dataclasses import dataclass
- from typing import List
- import numpy as np
- np.random.seed(0)  # fixed seed: makes the np.random.randint benchmark inputs reproducible across runs
- @dataclass
- class WrappedArray:
- arr: np.ndarray
- def __matmul__(self, other) -> int | float:
- return (self.arr @ other.arr).item()
def get_naive_method_mat(fore_states: List[WrappedArray], post_states: List[WrappedArray]) -> np.ndarray:
    """Build the matrix of pairwise ``@`` products with plain Python loops.

    Entry ``[i, j]`` of the result is ``fore_states[i] @ post_states[j]``.
    One ``@`` call is made per pair, so the work grows with
    ``len(fore_states) * len(post_states)``.
    """
    rows = []
    for left in fore_states:
        # One output row per fore state, one column per post state.
        row = [left @ right for right in post_states]
        rows.append(row)
    return np.array(rows)
def get_np_fromiter_method_mat(fore_states: List[WrappedArray], post_states: List[WrappedArray]) -> np.ndarray:
    """Build the matrix of pairwise ``@`` products via ``np.meshgrid`` and ``np.fromiter``.

    Entry ``[i, j]`` of the result is ``fore_states[i] @ post_states[j]``,
    stored as ``float``.

    Args:
        fore_states: Left-hand operands; one output row per element.
        post_states: Right-hand operands; one output column per element.

    Returns:
        A float array of shape ``(len(fore_states), len(post_states))``.
    """
    # meshgrid + transpose + reshape yields one (fore, post) pair per row,
    # ordered fore-major: every post state for fore_states[0] comes first,
    # then every post state for fore_states[1], and so on.
    pairs = np.array(np.meshgrid(fore_states, post_states)).T.reshape(-1, 2)
    products = (fore_state @ post_state for fore_state, post_state in pairs)
    # Take the output shape from the arguments themselves instead of a
    # module-level global, so the function also works when imported and for
    # inputs of different lengths.
    return np.fromiter(products, dtype=float).reshape(
        len(fore_states), len(post_states)
    )
def get_arraize_method_mat(fore_states: List[WrappedArray], post_states: List[WrappedArray]) -> np.ndarray:
    """Build the matrix of pairwise products with a single matrix multiplication.

    Each wrapped array is flattened into one row; the stacked fore rows are
    then multiplied by the transpose of the stacked post rows, so entry
    ``[i, j]`` is the dot product of the flattened ``fore_states[i].arr`` and
    the flattened ``post_states[j].arr``.
    """
    left_rows = np.array([state.arr.flatten() for state in fore_states])
    right_rows = np.array([state.arr.flatten() for state in post_states])
    return np.matmul(left_rows, right_rows.T)
if __name__ == '__main__':
    # Prepare the data: N wrapped 1x2 row vectors and N wrapped 2x1 column
    # vectors filled with random integers in [0, 100).
    # N is deliberately left as a module-level name (not moved into a main()
    # function) so any function in this file that reads it keeps working.
    N = 1000
    fore_states = [
        WrappedArray(np.random.randint(0, 100, (1, 2)))
        for _ in range(N)
    ]
    post_states = [
        WrappedArray(np.random.randint(0, 100, (2, 1)))
        for _ in range(N)
    ]

    # Run each method. time.perf_counter() is a monotonic, high-resolution
    # clock intended for measuring short durations, and each interval wraps
    # only the call itself so printing is not counted.
    start = time.perf_counter()
    naive_method_mat = get_naive_method_mat(fore_states, post_states)
    elapsed = time.perf_counter() - start
    print(f'naive_method_mat takes {elapsed:.2f} seconds')

    start = time.perf_counter()
    np_fromiter_method_mat = get_np_fromiter_method_mat(fore_states, post_states)
    elapsed = time.perf_counter() - start
    print(f'np_fromiter_method_mat takes {elapsed:.2f} seconds')

    start = time.perf_counter()
    arraize_method_mat = get_arraize_method_mat(fore_states, post_states)
    elapsed = time.perf_counter() - start
    print(f'arraize_method_mat takes {elapsed:.2f} seconds')

    # Sanity check: all three methods must produce the same matrix.
    assert np.array_equal(naive_method_mat, np_fromiter_method_mat)
    assert np.array_equal(naive_method_mat, arraize_method_mat)

# Sample output recorded by the original author (kept verbatim).
""" output
naive_method_mat takes 1.24 seconds
np_fromiter_method_mat takes 2.21 seconds
arraize_method_mat takes 0.02 seconds
"""
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement