Advertisement
Guest User

Untitled

a guest
Dec 14th, 2018
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 14.98 KB | None | 0 0
  1. void *block_reverse (void *arg)
  2. {
  3.   args *my_arg = (args *) arg;
  4.   int idx_p = my_arg->thread_idx,
  5.       p = my_arg->p,
  6.       n = my_arg->n,
  7.       m = my_arg->m,
  8.       s = my_arg->s,
  9.       l = my_arg->l;
  10.   double *mtx = my_arg->mtx,
  11.          *ans = my_arg->ans;
  12.   char *fname = my_arg->sourse;
  13.   double t_time = 0., time = 0., resid_time = 0.;
  14.   static bool err_flag = false;
  15.   static pthread_mutex_t mutex_err = PTHREAD_MUTEX_INITIALIZER;
  16.  
  17.   cpu_set_t cpu;
  18.   CPU_ZERO (&cpu);
  19.   CPU_SET (idx_p, &cpu);
  20.   pthread_setaffinity_np (pthread_self(), sizeof (cpu), &cpu);
  21.  
  22.   for (int i = idx_p; i < l; i += p)
  23.     {
  24.       clean (mtx + i * m * n, m, n);
  25.       clean (ans + i * m * n, m, n);
  26.     }
  27.   if (s && idx_p == l % p)
  28.     {
  29.       clean (mtx + l * m * n, s, n);
  30.       clean (ans + l * m * n, s, n);
  31.     }
  32.   /*
  33.   int i, j, shift, x = n * m, y = m * m;
  34.   for (i = 0; i < l; i++)
  35.     {
  36.       shift = (i + p - 1 - idx_p) / p;
  37.       for (j = idx_p + num_of_threads * shift; j < l; j += p)
  38.         {
  39.           ptr = mtx + j * x + i * y;
  40.           clean (ptr, m, m);
  41.         }
  42.       if (j == l)
  43.         {
  44.           ptr = mtx + l * x + i * s * m;
  45.           clean (ptr, s, m);
  46.         }
  47.       shift = (i + p - idx_p) / p;
  48.       for (j = idx_p + p * shift; j < l; j += p)
  49.         {
  50.           ptr = mtx + i * x + j * y;
  51.           clean (ptr, m, m);
  52.         }
  53.       if (j == l)
  54.         {
  55.           ptr = mtx + i * x + l * y;
  56.           clean (ptr, s, m);
  57.         }
  58.     }
  59.   */
  60.   reduce (p);
  61.   if (fname)
  62.     {
  63.       if (idx_p == 0)
  64.         {
  65.           int ret = read_mtx (mtx, n, m, s, l, fname);
  66.           if (ret < 0)
  67.             {
  68.               my_arg->err = -100;
  69.               switch (ret)
  70.                 {
  71.                   case -1:
  72.                     fprintf (stdout, "Cannot open %s!\n", fname);
  73.                     break;
  74.                   case -2:
  75.                     fprintf (stdout, "Cannot read %s!\n", fname);
  76.                     break;
  77.                   case -3:
  78.                     fprintf (stdout, "%s is empty!\n", fname);
  79.                     break;
  80.                   case -4:
  81.                     fprintf (stdout, "Not enough data in %s!\n", fname);
  82.                     break;
  83.                   default:
  84.                     fprintf (stdout, "Error %d in %s!\n", ret, fname);
  85.                 }
  86.               return 0;
  87.             }
  88.         }
  89.       reduce (p);
  90.     }
  91.   else
  92.     {
  93.       if (idx_p == 0)
  94.         init_mtx (mtx, n, m, s, l, init);
  95.       reduce (p);
  96.     }
  97.   if (idx_p == 0)
  98.     {
  99.       fprintf (stdout, "Matrix A:\n");
  100.       print_mtx (stdout, mtx, n, m, s, l);
  101.     }
  102.   reduce (p);
  103.   static double norm;
  104.   if (idx_p == 0)
  105.     norm = norm_matrix (mtx, n, m, s, l);
  106.   reduce (p);
  107.   my_arg->norm *= norm;
  108.  
  109.   if (idx_p == 0)
  110.     time = get_time();
  111.  
  112.   t_time = get_cpu_time();
  113.   LU_decomp (my_arg);
  114.   for (int i = idx_p; i < l; i += p)
  115.     copy (ans + i * m * n, mtx + i * m * n, m, n);
  116.   if (s && idx_p == l % p)
  117.     copy (ans + l * m * n, mtx + l * m * n, s, n);
  118.   reduce (p);
  119.   fprintf (stdout, "decomp: [%d] = %.3f\n", idx_p, get_cpu_time() - t_time);
  120.   my_arg->time += (get_cpu_time() - t_time);
  121.   t_time = get_cpu_time();
  122.   reverse_L (my_arg);
  123.   fprintf (stdout, "rev_L: [%d] = %.3f\n", idx_p, get_cpu_time() - t_time);
  124.   my_arg->time += (get_cpu_time() - t_time);
  125.   t_time = get_cpu_time();
  126.   reverse_U (my_arg);
  127.   fprintf (stdout, "rev_U: [%d] = %.3f\n", idx_p, get_cpu_time() - t_time);
  128.   my_arg->time += (get_cpu_time() - t_time);
  129.   t_time = get_cpu_time();
  130.   mult_rev_UL (my_arg);
  131.   fprintf (stdout, "mult_UL: [%d] = %.3f\n", idx_p, get_cpu_time() - t_time);
  132.   my_arg->time += (get_cpu_time() - t_time);
  133.  
  134.   if (idx_p == 0)
  135.     time = get_time() - time;
  136.  
  137.   if (my_arg->err != 0)
  138.     {
  139.       pthread_mutex_lock (&mutex_err);
  140.       err_flag = true;
  141.       pthread_mutex_unlock (&mutex_err);
  142.     }
  143.   reduce (p);
  144.   if (!err_flag)
  145.     {
  146.       if (idx_p == 0)
  147.         {
  148.           //time = my_arg->time;
  149.           fprintf (stdout, "\nMatrix A^(-1):\n");
  150.           print_mtx (stdout, ans, n, m, s, l);
  151.           fprintf (stdout, "Elapsed: %.2f\n", time);
  152.         }
  153.       if (!(n > 6000 && p == 1))
  154.         {
  155.           if (idx_p == 0)
  156.             {
  157.               if (fname)
  158.                 read_mtx (mtx, n, m, s, l, fname);
  159.               else
  160.                 init_mtx (mtx, n, m, s, l, init);
  161.             }
  162.           reduce (p);
  163.           resid_time = get_time();
  164.           block_residual (my_arg);
  165.           reduce (p);
  166.           resid_time = get_time() - resid_time;
  167.           if (idx_p == 0)
  168.             {
  169.               fprintf (stdout, "Residual = %e\n", my_arg->norm);
  170.               fprintf (stdout, "Elapsed for residual: %.2f\n", resid_time);
  171.             }
  172.         }
  173.     }
  174.   reduce (p);
  175.   return 0;
  176. }
  177.  
  178. void LU_decomp (args *my_arg)
  179. {
  180.   double *mtx = my_arg->mtx;
  181.   int n = my_arg->n,
  182.       m = my_arg->m,
  183.       s = my_arg->s,
  184.       l = my_arg->l;
  185.   int idx_p = my_arg->thread_idx,
  186.       p = my_arg->p;
  187.   double *ptr = mtx, *ptr1 = mtx, *ptr2 = mtx, *diag = mtx;
  188.   double *buf = new double [m * m];
  189.   double *sum = new double [m * m];
  190.   double *rev_diag = new double [m * m];
  191.  
  192.   copy (rev_diag, diag, m, m);
  193.   if (mtx_reverse (rev_diag, buf, m, my_arg->norm) < 0)
  194.     {
  195.       my_arg->err = -1;
  196.     }
  197.   copy (rev_diag, buf, m, m);
  198.   /*
  199.   if (idx_p == 0)
  200.     {
  201.       for (int i = 1; i < l; i++)
  202.         {
  203.           ptr += m * m;
  204.           mtx_prod (rev_diag, ptr, buf, m, m, m);
  205.           copy (ptr, buf, m, m);
  206.         }
  207.       if (s)
  208.         {
  209.           ptr += m * m;
  210.           mtx_prod (rev_diag, ptr, buf, m, m, s);
  211.           copy (ptr, buf, m, s);
  212.         }
  213.     }
  214.   */
  215.   for (int i = 1 + idx_p; i < l; i += p)
  216.     {
  217.       ptr = mtx + m * m * i;
  218.       mtx_prod (rev_diag, ptr, buf, m, m, m);
  219.       copy (ptr, buf, m, m);
  220.     }
  221.   if (s && idx_p == 0)
  222.     {
  223.       ptr = mtx + l * m * m;
  224.       mtx_prod (rev_diag, ptr, buf, m, m, s);
  225.       copy (ptr, buf, m, s);
  226.     }
  227.   //
  228.   reduce (p);
  229.   diag += m * n + m * m;
  230.   clean (buf, m, m);
  231.   for (int i = 1 ; i < l ; i++)
  232.     {
  233.       int new_i = i + idx_p;
  234.       int board_thread = (l - i) % p;
  235.       ptr2 = mtx + new_i * m * n;
  236.       for (int j = new_i ; j < l ; j += p)
  237.         {
  238.           clean (sum, m, m);
  239.           ptr1 = mtx + i * m * m;
  240.           for (int r = 0, k = 0 ; k < i ; r += m * m, k++)
  241.             {
  242.               mtx_prod (ptr2 + r, ptr1, buf, m, m, m);
  243.               mtx_add (sum, buf, sum, m, m);
  244.               ptr1 += m * n;
  245.             }
  246.           mtx_dif (ptr2 + i * m * m, sum, ptr2 + i * m * m, m, m);
  247.           ptr2 += p * m * n;
  248.         }
  249.       if (s && idx_p == board_thread)
  250.         {
  251.           clean (sum, s, m);
  252.           ptr1 = mtx + i * m * m;
  253.           for (int r = 0, k = 0 ; k < i ; r += m * s, k++)
  254.             {
  255.               mtx_prod (ptr2 + r, ptr1, buf, s, m, m);
  256.               mtx_add (sum, buf, sum, s, m);
  257.               ptr1 += m * n;
  258.             }
  259.           mtx_dif (ptr2 + i * m * s, sum, ptr2 + i * m * s, s, m);
  260.         }
  261.       reduce (p);
  262.       copy (rev_diag, diag, m, m);
  263.       if (mtx_reverse (rev_diag, buf, m, my_arg->norm) < 0)
  264.          {
  265.            my_arg->err = -1;
  266.            break;
  267.          }
  268.       copy (rev_diag, buf, m, m);
  269.       ptr2 = mtx + i * m * n;
  270.       for (int k = new_i + 1, r = (new_i + 1) * m * m; k < l ; k += p, r += m * m * p)
  271.         {
  272.           clean (sum, m, m);
  273.           ptr1 = mtx;
  274.           for (int j = 0, row_i = 0 ; j < i ; j++, row_i += m * m)
  275.             {
  276.               mtx_prod (ptr2 + row_i, ptr1 + r, buf, m, m, m);
  277.               mtx_add (sum, buf, sum, m, m);
  278.               ptr1 += m * n;
  279.             }
  280.           mtx_dif (ptr2 + r, sum, buf, m, m);
  281.           mtx_prod (rev_diag, buf, ptr2 + r, m, m, m);
  282.         }
  283.       if (s && idx_p == board_thread)
  284.         {
  285.           clean (sum, m, s);
  286.           ptr1 = mtx;
  287.           for (int r = 0, k = 0 ; k < i ; r += m * m, k++)
  288.             {
  289.               mtx_prod (ptr2 + r, ptr1 + l * m * m, buf, m, m, s);
  290.               mtx_add (sum, buf, sum, m, s);
  291.               ptr1 += m * n;
  292.             }
  293.           mtx_dif (ptr2 + l * m * m, sum, buf, m, s);
  294.           mtx_prod (rev_diag, buf, ptr2 + l * m * m, m, m, s);
  295.         }
  296.       diag += m * n + m * m;
  297.       reduce (p);
  298.     }
  299.   if (s && idx_p == l % p)
  300.     {
  301.       ptr2 = mtx + l * m * n;
  302.       clean (sum, s, s);
  303.       ptr1 = mtx + l * m * m;
  304.       for (int r = 0, k = 0 ; k < l ; r += m * s, k++)
  305.         {
  306.           mtx_prod (ptr2 + r, ptr1, buf, s, m, s);
  307.           mtx_add (sum, buf, sum, s, s);
  308.           ptr1 += m * n;
  309.         }
  310.       mtx_dif (ptr2 + l * m * s, sum, ptr2 + l * m * s, s, s);
  311.     }
  312.   delete [] buf;
  313.   delete [] sum;
  314.   delete [] rev_diag;
  315.   reduce (p);
  316.   return;
  317. }
  318.  
  319. void reverse_L (args *my_arg)
  320. {
  321.   double *mtx = my_arg->mtx,
  322.          *ans = my_arg->ans;
  323.   int n = my_arg->n,
  324.       m = my_arg->m,
  325.       s = my_arg->s,
  326.       l = my_arg->l;
  327.   int idx_p = my_arg->thread_idx,
  328.       p = my_arg->p;
  329.   double *diag = mtx, *ptr1 = mtx, *ptr2 = mtx;
  330.   double *rev_diag, *buf, *sum;
  331.   buf = new double [m * m];
  332.   sum = new double [m * m];
  333.   rev_diag = new double [m * m];
  334.   double *whole_row  = new double [m * n];
  335.   for (int i = 0 ; i < l ; i++)
  336.     {
  337.       copy (rev_diag, diag, m, m);
  338.       if (mtx_reverse (rev_diag, buf, m, my_arg->norm) < 0)
  339.         {
  340.           my_arg->err = -1;
  341.         }
  342.       if (idx_p == i % p)
  343.         {
  344.           copy (diag, buf, m, m);
  345.         }
  346.       reduce (p);
  347.       copy (whole_row, ans, m, n);
  348.       for (int j = idx_p, k = idx_p * m * m; j < i; j += p, k += m * m * p)
  349.         {
  350.           clean (sum, m, m);
  351.           ptr2 = mtx + j * m * n;
  352.           for (int cnt = j, r = j * m * m; cnt < i; cnt++, r += m * m)
  353.             {
  354.               mtx_prod (whole_row + r, ptr2 + k, buf, m, m, m);
  355.               mtx_add (sum, buf, sum, m, m);
  356.               ptr2 += m * n;
  357.             }
  358.           clean (ptr1 + k, m, m);
  359.           mtx_dif (ptr1 + k, sum, buf, m, m);
  360.           mtx_prod (diag, buf, ptr1 + k, m, m, m);
  361.         }
  362.       ptr1 += m * n;
  363.       ans += m * n;
  364.       diag += m * n + m * m;
  365.       reduce (p);
  366.     }
  367.   diag = mtx + l * (m * n + m * s);
  368.   if (s)
  369.     {
  370.       copy (rev_diag, diag, s, s);
  371.       if (mtx_reverse (rev_diag, buf, s, my_arg->norm) < 0)
  372.         {
  373.           my_arg->err = -1;
  374.         }
  375.       if (idx_p == l % p)
  376.         {
  377.           copy (diag, buf, s, s);
  378.         }
  379.       reduce (p);
  380.       for (int j = idx_p, k = idx_p * m * m; j < l; j += p, k += m * m * p)
  381.         {
  382.           clean (sum, m, m);
  383.           ptr2 = mtx + j * m * n;
  384.           for (int cnt = j, r = j * m * s; cnt < l; cnt++, r += m * s)
  385.             {
  386.               mtx_prod (ans + r, ptr2 + k, buf, s, m, m);
  387.               mtx_add (sum, buf, sum, s, m);
  388.               ptr2 += m * n;
  389.             }
  390.           clean (ptr1 + j * m * s, s, m);
  391.           mtx_dif (ptr1 + j * m * s, sum, buf, s, m);
  392.           mtx_prod (diag, buf, ptr1 + j * m * s, s, s, m);
  393.         }
  394.     }
  395.   delete [] buf;
  396.   delete [] sum;
  397.   delete [] rev_diag;
  398.   delete [] whole_row;
  399.   reduce (p);
  400.   return;
  401. }
  402.  
  403. void reverse_U (args *my_arg)
  404. {
  405.   double *mtx = my_arg->mtx,
  406.          *ans = my_arg->ans;
  407.   int n = my_arg->n,
  408.       m = my_arg->m,
  409.       s = my_arg->s,
  410.       l = my_arg->l;
  411.   int idx_p = my_arg->thread_idx,
  412.       p = my_arg->p;
  413.   double *ptr1 = mtx, *ptr2 = mtx;
  414.   double *buf, *sum;
  415.   buf = new double [m * m];
  416.   sum = new double [m * m];
  417.   for (int i = idx_p ; i < l ; i += p)
  418.     {
  419.       ptr1 = mtx + i * m * n;
  420.       for (int j = i + 1, k = (i + 1) * m * m; j < l; j++, k += m * m)
  421.         {
  422.           clean (sum, m, m);
  423.           ptr2 = ans + i * m * n;
  424.           mtx_add (sum, ptr2 + k, sum, m, m);
  425.           ptr2 += m * n;
  426.           for (int cnt = i + 1, r = (i + 1) * m * m; cnt < j; cnt++, r += m * m)
  427.             {
  428.               mtx_prod (ptr1 + r, ptr2 + k, buf, m, m, m);
  429.               mtx_add (sum, buf, sum, m, m);
  430.               ptr2 += m * n;
  431.             }
  432.           clean (ptr1 + k, m, m);
  433.           mtx_dif (ptr1 + k, sum, ptr1 + k, m, m);
  434.         }
  435.       if (s)
  436.         {
  437.           clean (sum, m, s);
  438.           ptr2 = ans + i * m * n;
  439.           mtx_add (sum, ptr2 + l * m * m, sum, m, s);
  440.           ptr2 += m * n;
  441.           for (int cnt = i + 1, r = (i + 1) * m * m; cnt < l; cnt++, r += m * m)
  442.             {
  443.               mtx_prod (ptr1 + r, ptr2 + l * m * m, buf, m, m, s);
  444.               mtx_add (sum, buf, sum, m, s);
  445.               ptr2 += m * n;
  446.             }
  447.           clean (ptr1 + l * m * m, m, s);
  448.           mtx_dif (ptr1 + l * m * m, sum, ptr1 + l * m * m, m, s);
  449.         }
  450.     }
  451.   delete [] buf;
  452.   delete [] sum;
  453.   reduce (p);
  454.   return;
  455. }
  456.  
  457. void mult_rev_UL (args *my_arg)
  458. {
  459.   double *mtx = my_arg->mtx,
  460.          *ans = my_arg->ans;
  461.   int n = my_arg->n,
  462.       m = my_arg->m,
  463.       s = my_arg->s,
  464.       l = my_arg->l;
  465.   int idx_p = my_arg->thread_idx,
  466.       p = my_arg->p;
  467.   double *ptr_L = mtx, *ptr_U = mtx, *ptr_buf = ans;
  468.   double *sum, *buf;
  469.   buf = new double [m * m];
  470.   sum = new double [m * m];
  471.   for (int i = 0, row_i = 0; i < l; i++, row_i += m * n)
  472.     {
  473.       for (int j = idx_p, k = idx_p * m * m; j <= i; j += p, k += m * m * p)
  474.         {
  475.           clean (sum, m, m);
  476.           ptr_L = mtx + (i + 1) * m * n;
  477.           for (int cnt = i + 1, r = (i + 1) * m * m; cnt < l; cnt++, r += m * m)
  478.             {
  479.               mtx_prod (ptr_U + r, ptr_L + k, buf, m, m, m);
  480.               mtx_add (sum, buf, sum, m, m);
  481.               ptr_L += m * n;
  482.             }
  483.           if (s)
  484.             {
  485.               mtx_prod (ptr_U + l * m * m, ptr_L + j * m * s, buf, m, s, m);
  486.               mtx_add (sum, buf, sum, m, m);
  487.             }
  488.           mtx_add (mtx + row_i + k, sum, ptr_buf + k, m, m);
  489.         }
  490.       for (int j = i + 1 + idx_p, k = (i + 1 + idx_p) * m * m; j < l; j += p, k += m * m * p)
  491.         {
  492.           clean (sum, m, m);
  493.           ptr_L = mtx + j * m * n;
  494.           for (int cnt = j, r = k; cnt < l; cnt++, r += m * m)
  495.             {
  496.               mtx_prod (ptr_U + r, ptr_L + k, buf, m, m, m);
  497.               mtx_add (sum, buf, sum, m, m);
  498.               ptr_L += m * n;
  499.             }
  500.           if (s)
  501.             {
  502.               mtx_prod (ptr_U + l * m * m, ptr_L + j * m * s, buf, m, s, m);
  503.               mtx_add (sum, buf, sum, m, m);
  504.               ptr_L += m * n;
  505.             }
  506.           copy (ptr_buf + k, sum, m, m);
  507.         }
  508.       if (s && idx_p == i % p)
  509.         {
  510.           clean (sum, m, m);
  511.           mtx_prod (ptr_U + l * m * m, mtx + l * m * n + l * m * s, buf, m, s, s);
  512.           mtx_add (sum, buf, ptr_buf + l * m * m, m, s);
  513.         }
  514.       ptr_U += m * n;
  515.       ptr_buf += m * n;
  516.       reduce (p);
  517.     }
  518.   if (s && idx_p == l % p)
  519.     {
  520.       ptr_L = mtx + l * m * n;
  521.       for (int j = 0, k = 0; j < l; j++, k += s * m)
  522.         copy (ptr_buf + k, ptr_L + k, s, m);
  523.       copy (ptr_buf + l * s * m, ptr_L + l * s * m, s, s);
  524.     }
  525.   delete [] buf;
  526.   delete [] sum;
  527.   reduce (p);
  528.   return;
  529. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement