hedgefund

zig_matmul_SIMD

Jan 8th, 2025
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 3.42 KB | Source Code | 0 0
  1. const std = @import("std");
  2. const print = std.debug.print;
  3. const math = std.math;
  4. const fabs = std.zig.c_builtins.__builtin_fabsf;
  5. const tMilli = std.time.milliTimestamp;
  6. const rand = std.rand;
  7.  
  8. const N = 1024; // Matrix dimension: each matrix is N x N, stored row-major in a flat slice of N * N f32 values
  9. const simdWidth = 8; // f32 lanes processed per vector operation (8 x 32 bits = 256 bits, the AVX2 register width)
  10.  
  /// Fill the leading `size * size` elements of `matrix` with random `f32` values.
  ///
  /// The matrix is treated as a row-major `size` x `size` grid stored in a flat
  /// slice. Every element receives exactly one draw of `rng.random().float(f32)`,
  /// taken in row-major order, so a PRNG in a given state always yields the same
  /// matrix and advances by exactly `size * size` draws.
  ///
  /// `matrix.len` must be at least `size * size`; the slice expression below
  /// enforces this in safety-checked builds. Elements past `size * size` are
  /// left untouched.
  fn initializeMatrix(matrix: []f32, size: usize, rng: *rand.DefaultPrng) void {
      const random = rng.random();
      // Row-major traversal of a flat buffer is a single linear pass, so no
      // nested (i, j) loops or index arithmetic are needed.
      for (matrix[0 .. size * size]) |*cell| {
          cell.* = random.float(f32);
      }
  }
  20.  
  /// Compute `C = A * B` for square, row-major `size` x `size` `f32` matrices.
  ///
  /// `A`, `B` and `C` are flat slices holding at least `size * size` elements
  /// each. For every output row, columns are produced `simdWidth` at a time:
  /// one element of `A` is broadcast across a vector and multiplied with
  /// `simdWidth` consecutive elements from a row of `B`, accumulating over `k`.
  ///
  /// When `size` is not a multiple of `simdWidth`, the remaining columns of
  /// each row are computed by a scalar loop, so no load or store ever reaches
  /// past the end of a row. All accesses go through slices and are therefore
  /// bounds-checked in safety-checked builds.
  fn matrixMultiplySimd(A: []const f32, B: []const f32, C: []f32, size: usize) void {
      const Vec = @Vector(simdWidth, f32);
      // Number of leading columns that can be covered by whole vectors.
      const vector_cols = size - size % simdWidth;

      for (0..size) |i| {
          const a_row = A[i * size ..][0..size];
          const c_row = C[i * size ..][0..size];

          // Vectorised part: simdWidth output columns per iteration.
          var j: usize = 0;
          while (j < vector_cols) : (j += simdWidth) {
              var acc: Vec = @splat(@as(f32, 0.0));
              for (a_row, 0..) |a_val, k| {
                  const a_vec: Vec = @splat(a_val);
                  // A comptime-length slice yields a pointer to an array, and
                  // arrays coerce to vectors; the element alignment of the
                  // slice is kept, so this is an unaligned-safe load.
                  const b_vec: Vec = B[k * size + j ..][0..simdWidth].*;
                  acc += a_vec * b_vec;
              }
              c_row[j..][0..simdWidth].* = acc;
          }

          // Scalar tail: the last `size % simdWidth` columns of the row.
          while (j < size) : (j += 1) {
              var sum: f32 = 0.0;
              for (a_row, 0..) |a_val, k| {
                  sum += a_val * B[k * size + j];
              }
              c_row[j] = sum;
          }
      }
  }
  39.  
  40. // // Perform matrix multiplication without SIMD (for comparison)
  41. // fn matrixMultiplyScalar(A: []const f32, B: []const f32, C: []f32, size: usize) void {
  42. //     for (0..size) |i| {
  43. //         for (0..size) |j| {
  44. //             var sum: f32 = 0.0;
  45. //             for (0..size) |k| {
  46. //                 sum += A[i * size + k] * B[k * size + j];
  47. //             }
  48. //             C[i * size + j] = sum;
  49. //         }
  50. //     }
  51. // }
  52.  
  /// Benchmark entry point.
  ///
  /// Allocates three N x N `f32` matrices, fills `A` and `B` with random
  /// values, runs `matrixMultiplySimd` once and prints the elapsed time in
  /// milliseconds. Returns an error if an allocation fails or if no monotonic
  /// timer is available on the host.
  pub fn main() !void {
      // Allocate memory for matrices
      const allocator = std.heap.page_allocator;
      const A = try allocator.alloc(f32, N * N);
      defer allocator.free(A);
      const B = try allocator.alloc(f32, N * N);
      defer allocator.free(B);
      const C_simd = try allocator.alloc(f32, N * N);
      defer allocator.free(C_simd);
      // const C_scalar = try allocator.alloc(f32, N * N);
      // defer allocator.free(C_scalar);

      // Initialize matrices with random values. The wall clock is used only as
      // a seed here, so each run multiplies different matrices.
      var rng = rand.DefaultPrng.init(@intCast(tMilli()));
      initializeMatrix(A, N, &rng);
      initializeMatrix(B, N, &rng);

      // Benchmark SIMD matrix multiplication. A monotonic timer is used rather
      // than wall-clock timestamps: it cannot jump when the system clock is
      // adjusted and it has nanosecond rather than millisecond resolution.
      var timer = try std.time.Timer.start();
      matrixMultiplySimd(A, B, C_simd, N);
      const elapsed_ns = timer.read();
      const ns_per_ms: f64 = std.time.ns_per_ms;
      const simd_time = @as(f64, @floatFromInt(elapsed_ns)) / ns_per_ms;
      print("SIMD Matrix Multiplication Time: {d:.2} ms\n", .{simd_time});

      // Benchmark scalar matrix multiplication (disabled; needs the
      // commented-out matrixMultiplyScalar and C_scalar above).
      // const start_scalar = tMilli();
      // matrixMultiplyScalar(A, B, C_scalar, N);
      // const end_scalar = tMilli();
      // const scalar_time = @as(f32, @floatFromInt(end_scalar - start_scalar));
      // print("Scalar Matrix Multiplication Time: {d:.2} ms\n", .{scalar_time});
      //
      // // Verify correctness
      // for (0..N) |i| {
      //     for (0..N) |j| {
      //         if (fabs(C_simd[i * N + j] - C_scalar[i * N + j]) > 1e-5) {
      //             print("Mismatch at ({d}, {d}): SIMD={d}, Scalar={d}\n", .{ i, j, C_simd[i * N + j], C_scalar[i * N + j] });
      //             return;
      //         }
      //     }
      // }
      // print("Results match!\n", .{});
  }
  95.  
Advertisement
Add Comment
Please, Sign In to add comment