Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- use core::arch::x86_64::{
- _mm256_add_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps, _mm256_setzero_ps,
- _mm256_storeu_ps,
- };
- use rand::Rng;
- use std::time::Instant;
/// Dimension of the square matrices being multiplied (`N x N`).
const N: usize = 1024;
/// Number of `f32` lanes in one 256-bit AVX vector, i.e. how many elements
/// each SIMD operation processes. The `_mm256_*_ps` intrinsics used by this
/// program belong to AVX (not AVX2).
const SIMD_WIDTH: usize = 8;
- // Function to initialize a matrix with random values
- fn initialize_matrix(matrix: &mut [f32], size: usize) {
- let mut rng = rand::thread_rng();
- for i in 0..size {
- for j in 0..size {
- matrix[i * size + j] = rng.gen::<f32>();
- }
- }
- }
- // Function to perform matrix multiplication using SIMD
- unsafe fn matrix_multiply_simd(a: &[f32], b: &[f32], c: &mut [f32], size: usize) {
- for i in 0..size {
- for j in (0..size).step_by(SIMD_WIDTH) {
- let mut c_simd = _mm256_setzero_ps();
- for k in 0..size {
- let a_val = _mm256_set1_ps(a[i * size + k]);
- let b_simd = _mm256_loadu_ps(b.as_ptr().add(k * size + j));
- c_simd = _mm256_add_ps(c_simd, _mm256_mul_ps(a_val, b_simd));
- }
- _mm256_storeu_ps(c.as_mut_ptr().add(i * size + j), c_simd);
- }
- }
- }
- // Function to perform matrix multiplication without SIMD (for comparison)
- //#[allow(dead_code)]
- //fn matrix_multiply_scalar(a: &[f32], b: &[f32], c: &mut [f32], size: usize) {
- // for i in 0..size {
- // for j in 0..size {
- // let mut sum = 0.0;
- // for k in 0..size {
- // sum += a[i * size + k] * b[k * size + j];
- // }
- // c[i * size + j] = sum;
- // }
- // }
- //}
- fn main() {
- // Allocate memory for matrices
- let mut a = vec![0.0; N * N];
- let mut b = vec![0.0; N * N];
- let mut c_simd = vec![0.0; N * N];
- //let mut c_scalar = vec![0.0; N * N];
- // Initialize matrices with random values
- initialize_matrix(&mut a, N);
- initialize_matrix(&mut b, N);
- // Benchmark SIMD matrix multiplication
- let start_simd = Instant::now();
- unsafe {
- matrix_multiply_simd(&a, &b, &mut c_simd, N);
- }
- let simd_time = start_simd.elapsed().as_millis();
- println!("SIMD Matrix Multiplication Time: {:.2} ms", simd_time);
- // Print a 4x4 block of the resulting matrix to verify correctness
- println!("Verifying a 4x4 block of the resulting matrix:");
- for i in 0..4 {
- for j in 0..4 {
- print!("{:.2} ", c_simd[i * N + j]);
- }
- println!();
- }
- // Benchmark scalar matrix multiplication
- // let start_scalar = Instant::now();
- // matrix_multiply_scalar(&a, &b, &mut c_scalar, N);
- // let scalar_time = start_scalar.elapsed().as_secs_f64() * 1000.0;
- // println!("Scalar Matrix Multiplication Time: {:.2} ms", scalar_time);
- // Verify correctness (optional)
- // let mut mismatch = false;
- // for i in 0..N {
- // for j in 0..N {
- // if (c_simd[i * N + j] - c_scalar[i * N + j]).abs() > 1e-5 {
- // println!(
- // "Mismatch at ({}, {}): SIMD={}, Scalar={}",
- // i, j, c_simd[i * N + j], c_scalar[i * N + j]
- // );
- // mismatch = true;
- // }
- // }
- // }
- // if !mismatch {
- // println!("Results match!");
- // }
- }
Advertisement
Add Comment
Please, Sign In to add comment