Advertisement
Guest User

fftw

a guest
Feb 27th, 2020
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.44 KB | None | 0 0
#include <fftw3-mpi.h>
#include <mpi.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>

using namespace std;
  8.  
  9. static double main_function(const double x1, const double x2, const double x3) {
  10.     return exp(sin(x1 - x2 + x3));
  11. }
  12. static double derivative_x1(const double x1, const double x2, const double x3) {
  13.     return exp(sin(x1 - x2 + x3)) * cos(x1 - x2 + x3);
  14. }
  15. static double derivative_x2(const double x1, const double x2, const double x3) {
  16.     return -exp(sin(x1 - x2 + x3)) * cos(x1 - x2 + x3);
  17. }
  18. static double derivative_x3(const double x1, const double x2, const double x3) {
  19.     return exp(sin(x1 - x2 + x3)) * cos(x1 - x2 + x3);
  20. }
  21.  
  22. struct FFTW_ptrs {
  23.     ptrdiff_t alloc_local, local_n0, local_0_start;
  24.     FFTW_ptrs(ptrdiff_t _alloc_local, ptrdiff_t _local_n0, ptrdiff_t _local_0_start)
  25.     : alloc_local(_alloc_local),
  26.       local_n0(_local_n0),
  27.       local_0_start(_local_0_start)
  28.     {}
  29. };
  30.  
  31. void find_fourie_derivative(function<double(const double, const double, const double)> *derivatives,
  32.                             const ptrdiff_t N,
  33.                             FFTW_ptrs ptrs,
  34.                             const double range_from,
  35.                             const double range_to)
  36. {
  37.     fftw_plan forward, backward;
  38.     double *real          = fftw_alloc_real(ptrs.alloc_local * 2);
  39.     fftw_complex *complex = fftw_alloc_complex(ptrs.alloc_local);
  40.     forward  = fftw_mpi_plan_dft_r2c_3d(N, N, N, real, complex, MPI_COMM_WORLD, FFTW_MEASURE);
  41.     backward = fftw_mpi_plan_dft_c2r_3d(N, N, N, complex, real, MPI_COMM_WORLD, FFTW_MEASURE);
  42.     double coefficient = 0;
  43.     auto indices = new double[N];
  44.     for (ptrdiff_t i = 0; i <= N / 2; i++) {
  45.         indices[i] = i;
  46.     }
  47.     for (ptrdiff_t i = N / 2 + 1; i < N; i++) {
  48.         indices[i] = i - N;
  49.     }
  50.     for (int parameter = 0; parameter < 3; parameter++) {
  51.         //filling values of main function
  52.         for (ptrdiff_t i = 0; i < ptrs.local_n0; i++) {
  53.             for (ptrdiff_t j = 0; j < N; j++) {
  54.                 for (ptrdiff_t k = 0; k < N; k++) {
  55.                     double current_x1 = range_to * (ptrs.local_0_start + i) / N;
  56.                     double current_x2 = range_to * j / N;
  57.                     double current_x3 = range_to * k / N;
  58.                     real[(i * N + j) * (2 * (N / 2 + 1)) + k] = main_function(current_x1, current_x2, current_x3);
  59.                 }
  60.             }
  61.         }
  62.         fftw_execute(forward);
  63.  
  64.         for (ptrdiff_t i = 0; i < ptrs.local_n0; i++) {
  65.             for (ptrdiff_t j = 0; j < N; j++) {
  66.                 for (ptrdiff_t k = 0; k < N/2 + 1; k++) {
  67.                     switch (parameter) {
  68.                         case 0:
  69.                             coefficient = indices[ptrs.local_0_start + i];
  70.                             break;
  71.                         case 1:
  72.                             coefficient = indices[j];
  73.                             break;
  74.                         case 2:
  75.                             coefficient = k;
  76.                         default:
  77.                             break;
  78.                     }
  79.                     complex[(i * N + j) * (2 * (N / 2 + 1)) + k][0] *=  coefficient;
  80.                     complex[(i * N + j) * (2 * (N / 2 + 1)) + k][1] *= -coefficient;
  81.                     swap(complex[(i * N + j) * (2 * (N / 2 + 1)) + k][0],
  82.                             complex[(i * N + j) * (2 * (N / 2 + 1)) + k][1]);
  83.                 }
  84.             }
  85.         }
  86.  
  87.         fftw_execute(backward);
  88.  
  89.         for (ptrdiff_t i = 0; i < ptrs.local_n0; i++)
  90.             for (ptrdiff_t j = 0; j < N; j++)
  91.                 for (ptrdiff_t k = 0; k < N; k++)
  92.                     real[(i * N + j) * (2 * (N / 2 + 1)) + k] /= N*N*N;
  93.  
  94.         double max_diff = 0.0;
  95.  
  96.         for (ptrdiff_t i = 0; i < ptrs.local_n0; i++)
  97.             for (ptrdiff_t j = 0; j < N; j++)
  98.                 for (ptrdiff_t k = 0; k < N; k++) {
  99.                     double current_x1 = range_to * (ptrs.local_0_start + i) / N;
  100.                     double current_x2 = range_to * j / N;
  101.                     double current_x3 = range_to * k / N;
  102.                     max_diff = max(abs(real[(i * N + j) * (2 * (N / 2 + 1)) + k] -
  103.                                        derivatives[parameter](current_x1, current_x2, current_x3)), max_diff);
  104.                 }
  105.         cout << "Difference in derivative for x1: " << max_diff << endl;
  106.     }
  107.     fftw_free(real);
  108.     fftw_free(complex);
  109.     fftw_destroy_plan(forward);
  110.     fftw_destroy_plan(backward);
  111. }
  112.  
  113. int main(int argc, char **argv)
  114. {
  115.     if (argc != 2) {
  116.         cout << "Wrong number of arguments, program will be terminated\n";
  117.         return -1;
  118.     }
  119.     const ptrdiff_t N = strtol(argv[1], nullptr, 10);
  120.     ptrdiff_t alloc_local, local_n0, local_0_start;
  121.     alloc_local = fftw_mpi_local_size_3d(N, N, N / 2 + 1,
  122.                                          MPI_COMM_WORLD, &local_n0, &local_0_start);
  123.  
  124.     MPI_Init(&argc, &argv);
  125.     int rank, size;
  126.     MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  127.     MPI_Comm_size(MPI_COMM_WORLD, &size);
  128.     fftw_mpi_init();
  129.     FFTW_ptrs ptrs(alloc_local, local_n0, local_0_start);
  130.     const double rangeFrom = 0.0, rangeTo = 2 * M_PI;
  131.     function<double(const double, const double, const double)> derivatives[3];
  132.     derivatives[0] = derivative_x1;
  133.     derivatives[1] = derivative_x2;
  134.     derivatives[2] = derivative_x3;
  135.     find_fourie_derivative(derivatives, N, ptrs, rangeFrom, rangeTo);
  136.     MPI_Finalize();
  137.     return 0;
  138. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement