"""Multi-head attention in NumPy three ways: a naive per-head loop, a
vectorized version, and an einsum-"golfed" version, benchmarked against
each other."""

import time

import numpy as np


def softmax(x, axis):
    # Subtract the max along the reduction axis for numerical stability.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)


def attention(X, W_q, W_k, W_v):
    """Scaled dot-product attention for a single head."""
    d_k = W_k.shape[1]
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    scores = Q @ K.T / np.sqrt(d_k)
    attention_weights = softmax(scores, axis=-1)
    return attention_weights @ V


def multi_head_attention(X, W_q, W_k, W_v, W_o):
    """Naive reference: loop over heads, then concatenate per position."""
    seq_len = X.shape[0]
    n_head = W_q.shape[0]
    projected = []
    for n in range(n_head):
        output = attention(X, W_q[n], W_k[n], W_v[n])
        projected.append(output @ W_o[n])
    projected = np.array(projected)  # (n_head, seq_len, d_out)
    # Concatenate the per-head outputs for each sequence position.
    output = []
    for i in range(seq_len):
        output.append(projected[:, i, :].ravel())
    return np.array(output)  # (seq_len, n_head * d_out)


def multi_head_attention_vectorized(X, W_q, W_k, W_v, W_o):
    """All heads at once via batched einsum and matmul."""
    d_k = W_k.shape[-1]
    seq_len = X.shape[0]
    Q = np.einsum('si,hij->hsj', X, W_q)  # (n_head, seq_len, d_k)
    K = np.einsum('si,hik->hsk', X, W_k)  # (n_head, seq_len, d_k)
    V = np.einsum('si,hiv->hsv', X, W_v)  # (n_head, seq_len, d_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)
    output = weights @ V  # (n_head, seq_len, d_v)
    projected = np.einsum('hsv,hvd->hsd', output, W_o)
    # Interleave heads back into (seq_len, n_head * d_out). The original
    # reshaped to W_v.shape[2], which only happens to work when
    # d_v == n_head * d_out; -1 is correct for any projection width.
    return projected.transpose(1, 0, 2).reshape(seq_len, -1)


def multi_head_attention_golfed(X, W_q, W_k, W_v, W_o, optimize="optimal"):
    """The whole computation in two einsum calls; NumPy chooses the
    contraction order when optimize="optimal"."""
    # scores[h, s, t] = (X @ W_q[h]) @ (X @ W_k[h]).T, contracted in one shot.
    scores = np.einsum("si,hij,tm,hmj->hst", X, W_q, X, W_k, optimize=optimize)
    weights = softmax(W_k.shape[-1] ** -0.5 * scores, axis=-1)
    # weights @ (X @ W_v[h]) @ W_o[h], output already in (seq, head, d_out) order.
    projected = np.einsum("hst,ti,hiv,hvd->shd", weights, X, W_v, W_o,
                          optimize=optimize)
    return projected.reshape(X.shape[0], -1)


def main():
    input_dim = 256
    seq_len = 128
    d_k = 64
    d_v = input_dim
    n_head = 8

    assert input_dim % n_head == 0

    for _ in range(10):
        X = np.random.randn(seq_len, input_dim)
        W_q = np.random.randn(n_head, input_dim, d_k)
        W_k = np.random.randn(n_head, input_dim, d_k)
        W_v = np.random.randn(n_head, input_dim, d_v)
        W_o = np.random.randn(n_head, d_v, input_dim // n_head)

        t0 = time.perf_counter()
        y = multi_head_attention(X, W_q, W_k, W_v, W_o)
        t1 = time.perf_counter()
        y2 = multi_head_attention_vectorized(X, W_q, W_k, W_v, W_o)
        t2 = time.perf_counter()
        y3 = multi_head_attention_golfed(X, W_q, W_k, W_v, W_o)
        t3 = time.perf_counter()

        print(f"{(t1 - t0) * 1000:8.3f} ms naive multi-head attention")
        print(f"{(t2 - t1) * 1000:8.3f} ms vectorized multi-head attention")
        print(f"{(t3 - t2) * 1000:8.3f} ms golfed multi-head attention")
        print()

        # All three implementations must agree.
        assert np.allclose(y, y2)
        assert np.allclose(y, y3)


if __name__ == "__main__":
    main()
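
# --- Sketch: inspecting the einsum contraction plan -------------------------
# A minimal illustration (not part of the benchmark above) of what
# optimize="optimal" does inside the golfed version: np.einsum_path returns
# the chosen contraction order plus a human-readable cost report. The helper
# name and its default shapes are assumptions chosen to match main().
def show_einsum_path(seq_len=128, input_dim=256, d_k=64, n_head=8):
    X = np.random.randn(seq_len, input_dim)
    W_q = np.random.randn(n_head, input_dim, d_k)
    W_k = np.random.randn(n_head, input_dim, d_k)
    path, report = np.einsum_path("si,hij,tm,hmj->hst", X, W_q, X, W_k,
                                  optimize="optimal")
    # `path` can be passed back as optimize=path to skip re-planning on
    # repeated calls; `report` shows the estimated FLOP savings over a
    # naive left-to-right contraction.
    print(report)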