hedgefund

rust_matmul_SIMD

Jan 8th, 2025
44
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 3.33 KB | Source Code | 0 0
  1. use core::arch::x86_64::{
  2.     _mm256_add_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps, _mm256_setzero_ps,
  3.     _mm256_storeu_ps,
  4. };
  5. use rand::Rng;
  6. use std::time::Instant;
  7.  
  8. const N: usize = 1024; // Size of the matrix (N x N)
  9. const SIMD_WIDTH: usize = 8; // Number of elements processed at a time (AVX2)
  10.  
  11. // Function to initialize a matrix with random values
  12. fn initialize_matrix(matrix: &mut [f32], size: usize) {
  13.     let mut rng = rand::thread_rng();
  14.     for i in 0..size {
  15.         for j in 0..size {
  16.             matrix[i * size + j] = rng.gen::<f32>();
  17.         }
  18.     }
  19. }
  20.  
  21. // Function to perform matrix multiplication using SIMD
  22. unsafe fn matrix_multiply_simd(a: &[f32], b: &[f32], c: &mut [f32], size: usize) {
  23.     for i in 0..size {
  24.         for j in (0..size).step_by(SIMD_WIDTH) {
  25.             let mut c_simd = _mm256_setzero_ps();
  26.  
  27.             for k in 0..size {
  28.                 let a_val = _mm256_set1_ps(a[i * size + k]);
  29.                 let b_simd = _mm256_loadu_ps(b.as_ptr().add(k * size + j));
  30.                 c_simd = _mm256_add_ps(c_simd, _mm256_mul_ps(a_val, b_simd));
  31.             }
  32.  
  33.             _mm256_storeu_ps(c.as_mut_ptr().add(i * size + j), c_simd);
  34.         }
  35.     }
  36. }
  37.  
  38. // Function to perform matrix multiplication without SIMD (for comparison)
  39. //#[allow(dead_code)]
  40. //fn matrix_multiply_scalar(a: &[f32], b: &[f32], c: &mut [f32], size: usize) {
  41. //    for i in 0..size {
  42. //        for j in 0..size {
  43. //            let mut sum = 0.0;
  44. //            for k in 0..size {
  45. //                sum += a[i * size + k] * b[k * size + j];
  46. //            }
  47. //            c[i * size + j] = sum;
  48. //        }
  49. //    }
  50. //}
  51.  
  52. fn main() {
  53.     // Allocate memory for matrices
  54.     let mut a = vec![0.0; N * N];
  55.     let mut b = vec![0.0; N * N];
  56.     let mut c_simd = vec![0.0; N * N];
  57.     //let mut c_scalar = vec![0.0; N * N];
  58.  
  59.     // Initialize matrices with random values
  60.     initialize_matrix(&mut a, N);
  61.     initialize_matrix(&mut b, N);
  62.  
  63.     // Benchmark SIMD matrix multiplication
  64.     let start_simd = Instant::now();
  65.     unsafe {
  66.         matrix_multiply_simd(&a, &b, &mut c_simd, N);
  67.     }
  68.     let simd_time = start_simd.elapsed().as_millis();
  69.     println!("SIMD Matrix Multiplication Time: {:.2} ms", simd_time);
  70.  
  71.     // Print a 4x4 block of the resulting matrix to verify correctness
  72.     println!("Verifying a 4x4 block of the resulting matrix:");
  73.     for i in 0..4 {
  74.         for j in 0..4 {
  75.             print!("{:.2} ", c_simd[i * N + j]);
  76.         }
  77.         println!();
  78.     }
  79.     // Benchmark scalar matrix multiplication
  80.     // let start_scalar = Instant::now();
  81.     // matrix_multiply_scalar(&a, &b, &mut c_scalar, N);
  82.     // let scalar_time = start_scalar.elapsed().as_secs_f64() * 1000.0;
  83.     // println!("Scalar Matrix Multiplication Time: {:.2} ms", scalar_time);
  84.  
  85.     // Verify correctness (optional)
  86.     // let mut mismatch = false;
  87.     // for i in 0..N {
  88.     //     for j in 0..N {
  89.     //         if (c_simd[i * N + j] - c_scalar[i * N + j]).abs() > 1e-5 {
  90.     //             println!(
  91.     //                 "Mismatch at ({}, {}): SIMD={}, Scalar={}",
  92.     //                 i, j, c_simd[i * N + j], c_scalar[i * N + j]
  93.     //             );
  94.     //             mismatch = true;
  95.     //         }
  96.     //     }
  97.     // }
  98.     // if !mismatch {
  99.     //     println!("Results match!");
  100.     // }
  101. }
  102.  
Advertisement
Add Comment
Please, Sign In to add comment