zhangsongcui

simd_complex with test

Jun 8th, 2012
237
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.62 KB | None | 0 0
#include <chrono>
#include <cmath>
#include <complex>
#include <iostream>

#include <pmmintrin.h>
  5.  
  6. template <typename T>
  7. struct simd_complex;
  8.  
  9. template <>
  10. struct simd_complex<double> {
  11.     simd_complex(): value(_mm_setzero_pd()) {}
  12.     simd_complex(double r): value(_mm_set_sd(r)) {}
  13.     simd_complex(double r, double i): value(_mm_setr_pd(r, i)) {}
  14.     simd_complex(__m128d v): value(v) {}
  15.  
  16.     operator __m128d() { return value; }
  17.  
  18.     friend
  19.     simd_complex operator +(simd_complex a_b, simd_complex c_d) {
  20.         return _mm_add_pd(a_b, c_d);
  21.     }
  22.     simd_complex& operator +=(simd_complex rhs) {
  23.         return *this = *this + rhs;
  24.     }
  25.  
  26.     friend
  27.     simd_complex operator -(simd_complex a_b, simd_complex c_d) {
  28.         return _mm_sub_pd(a_b, c_d);
  29.     }
  30.     simd_complex& operator -=(simd_complex rhs) {
  31.         return *this = *this - rhs;
  32.     }
  33.  
  34.     friend
  35.     simd_complex operator *(simd_complex a_b, simd_complex c_d) {
  36.         __m128d c_c = _mm_movedup_pd(c_d);
  37.         __m128d ac_bc = _mm_mul_pd(a_b, c_c);
  38.  
  39.         __m128d b_a = _mm_shuffle_pd(a_b, a_b, 1);
  40.         __m128d d_d = _mm_unpackhi_pd(c_d, c_d);
  41.         __m128d bd_ad = _mm_mul_pd(b_a, d_d);
  42.         return _mm_addsub_pd(ac_bc, bd_ad);
  43.     }
  44.     simd_complex& operator *=(simd_complex rhs) {
  45.         return *this = *this * rhs;
  46.     }
  47.  
  48.     friend
  49.     simd_complex operator /(simd_complex a_b, simd_complex c_d) {
  50.         __m128d b_a = _mm_shuffle_pd(a_b, a_b, 1);
  51.         __m128d c_c = _mm_movedup_pd(c_d);
  52.         __m128d bc_ac = _mm_mul_pd(b_a, c_c);
  53.  
  54.         __m128d d_d = _mm_unpackhi_pd(c_d, c_d);
  55.         __m128d ad_bd = _mm_mul_pd(a_b, d_d);
  56.  
  57.         __m128d t = _mm_addsub_pd(bc_ac, ad_bd);  // (bc-ad) + (ac+bd)i
  58.         __m128d numerator = _mm_shuffle_pd(t, t, 1); // (ac+bd) + (bc-ad)i
  59.         __m128d cc_dd = _mm_mul_pd(c_d, c_d);
  60.         __m128d denominator = _mm_hadd_pd(cc_dd, cc_dd); // (cc+dd) + (cc+dd)i
  61.         return _mm_div_pd(numerator, denominator);
  62.     }
  63.     simd_complex operator /=(simd_complex rhs) {
  64.         return *this = *this / rhs;
  65.     }
  66.  
  67.     double real() { return reinterpret_cast<double (&)[2]>(value)[0]; }
  68.     double imag() { return reinterpret_cast<double (&)[2]>(value)[1]; }
  69.  
  70.     friend std::ostream& operator <<(std::ostream& o, simd_complex rhs) {
  71.         return o << '(' << rhs.real() << ',' << rhs.imag() << ')';
  72.     }
  73.  
  74. private:
  75.     __m128d value; // [0]: real, [1]: imag
  76. };
  77.  
  78. double abs(simd_complex<double> a_b) {
  79.     __m128d aa_bb = _mm_mul_pd(a_b, a_b);
  80.     __m128d aaAbb_aaAbb = _mm_hadd_pd(aa_bb, aa_bb);
  81.     __m128d result = _mm_sqrt_pd(aaAbb_aaAbb);
  82.     return reinterpret_cast<double (&)[2]>(result)[0];
  83. }
  84.  
  85. int main(void) {
  86.     using namespace std;
  87.     double a[4];
  88.     for (int i = 0; i < 4; ++i)
  89.         cin >> a[i];
  90.     auto st = chrono::steady_clock::now();
  91.     complex<double> c1(a[0], a[1]), c2(a[2], a[3]), c3(1, 1), c4(1, 1), c5, c6;
  92.     for (int i = 0; i < 100000000; ++i) {
  93.         c3 *= c1;
  94.         c4 /= c2;
  95.         c5 += c1;
  96.         c6 -= c2;
  97.     }
  98.     auto t = chrono::steady_clock::now() - st;
  99.     cout << c1 * c2 << endl << c1 / c2 << endl;
  100.     cout << c3 << c4 << c5 << c6 << endl << t.count() << endl;
  101.  
  102.     st = chrono::steady_clock::now();
  103.     simd_complex<double> d1(a[0], a[1]), d2(a[2], a[3]), d3(1, 1), d4(1, 1), d5, d6;
  104.     for (int i = 0; i < 100000000; ++i) {
  105.         d3 *= d1;
  106.         d4 /= d2;
  107.         d5 += d1;
  108.         d6 -= d2;
  109.     }
  110.     t = chrono::steady_clock::now() - st;
  111.     cout << d1 * d2 << endl << d1 / d2 << endl;
  112.     cout << d3 << d4 << d5 << d6 << endl << t.count() << endl;
  113.  
  114.     cout << abs(simd_complex<double>(3, 4)) << endl;
  115. }
Advertisement
Add Comment
Please, Sign In to add comment