Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <stdio.h>
#include <stdlib.h>

#include <mpi.h>
- const int INT_0 = 0;
- const int INT_1 = 1;
- int main(int argc, char *argv[]) {
- int rank, size;
- MPI_Init(&argc, &argv);
- MPI_Comm_rank(MPI_COMM_WORLD, &rank);
- MPI_Comm_size(MPI_COMM_WORLD, &size);
- int info;
- int ictxt, myrow, mycol;
- int nprow = 2, npcol = 2;
- int nb = 2;
- Cblacs_pinfo(&rank, &size);
- Cblacs_get(-1, 0, &ictxt);
- Cblacs_gridinit(&ictxt, "Row", nprow, npcol);
- Cblacs_gridinfo(ictxt, &nprow, &npcol, &myrow, &mycol);
- if (myrow == -1 || mycol == -1) goto print;
- int M = 4, N = 4;
- int mA = numroc_(&M, &nb, &myrow, &INT_0, &nprow);
- int nA = numroc_(&M, &nb, &mycol, &INT_0, &npcol);
- int descA[9];
- descinit_(descA, &M, &N, &nb, &nb, &INT_0, &INT_0, &ictxt, &mA, &info);
- double *A = (double*)malloc(mA * nA * sizeof(double));
- double *work = (double*)malloc(mA * sizeof(double));
- char fn[] = "A.mat";
- // pdlaread(fn, A, descA, &INT_0, &INT_0, work, strlen(fn));
- for (int i = 0; i < mA; ++i)
- for (int j = 0; j < nA; ++j)
- A[j*mA + i] = i*nA + j + 1;
- int mx = numroc_(&M, &nb, &myrow, &INT_0, &nprow);
- int nx = numroc_(&INT_1, &nb, &mycol, &INT_0, &npcol);
- int descx[9];
- descinit_(descx, &M, &INT_1, &nb, &INT_1, &INT_0, &INT_0, &ictxt, &mx, &info);
- double *x = (double*)malloc(mx * INT_1 * sizeof(double));
- int my = numroc_(&M, &nb, &myrow, &INT_0, &nprow);
- int ny = numroc_(&INT_1, &nb, &mycol, &INT_0, &npcol);
- int descy[9];
- descinit_(descy, &M, &INT_1, &nb, &INT_1, &INT_0, &INT_0, &ictxt, &my, &info);
- double *y = (double*)malloc(my * INT_1 * sizeof(double));
- for (int i = 0; i < mx; ++i) x[i] = 1.0;
- double alpha = 1.0, beta = 0.0;
- pdgemv("N", &M, &N, &alpha, A, &INT_1, &INT_1, descA, x, &INT_1, &INT_1,
- descx, &INT_1, &beta, y, &INT_1, &INT_1, descy, &INT_1);
- double norm;
- norm = pdlange("F", &M, &INT_1, y, &INT_1, &INT_1, descy, NULL);
- // printf("%.6lf\n");
- print:
- for (int r = 0; r < size; ++r) {
- if (rank == r) {
- printf("------ %d ------\n", r);
- printf("norma = %.6lf\n", norm);
- if (nx > 0) {
- for (int i = 0; i < mx; ++i) {
- printf("%.6lf\n", y[i]);
- }
- }
- }
- MPI_Barrier(MPI_COMM_WORLD);
- }
- MPI_Finalize();
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement