Advertisement
Guest User

quantify-kmeans.c

a guest
Feb 19th, 2013
279
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 6.51 KB | None | 0 0
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
  6.  
  7. #define D 3                       /* number of dimensions (3 for RGB) */
  8.  
  9. static int width, height;         /* original image width and height */
  10. static int N;                     /* number of samples (width x height) */
  11. static int K;                     /* number of clusters to create */
  12. static unsigned char (*data)[D];  /* N coordinates of each point */
  13. static unsigned char (*best)[D];  /* N best quantified coordinates (so far) */
  14. static int *idx;                  /* N index of nearest centroid per point */
  15. static int *err;                  /* N squared distance to nearest centroid */
  16. static unsigned char (*cntr)[D];  /* K coordinates of each centroid */
  17. static int *size;                 /* K number of points nearest per centroid */
  18. static long long (*sum)[D];       /* K sum of coords of points near centroid */
  19.  
  20. static void allocate()
  21. {
  22.     data = malloc(N*sizeof(*data));
  23.     assert(data != NULL);
  24.     best = malloc(N*sizeof(*best));
  25.     assert(best != NULL);
  26.     idx = malloc(N*sizeof(*idx));
  27.     assert(idx != NULL);
  28.     err = malloc(N*sizeof(*err));
  29.     assert(err != NULL);
  30.     cntr = malloc(K*sizeof(*cntr));
  31.     assert(cntr != NULL);
  32.     size = malloc(K*sizeof(*size));
  33.     assert(size != NULL);
  34.     sum = malloc(K*sizeof(*sum));
  35.     assert(sum != NULL);
  36. }
  37.  
  38. /* Read input.  Initializes `data'. */
  39. static void read_input(const char *path)
  40. {
  41.     int i, d, res, val;
  42.     FILE *fp;
  43.  
  44.     fp = fopen(path, "rb");
  45.     if (fp == NULL)
  46.     {
  47.         perror(path);
  48.         exit(EXIT_FAILURE);
  49.     }
  50.     assert(fp != NULL);
  51.     res = fscanf(fp, "P6 %d %d %d%*1[\n]", &width, &height, &val);
  52.     assert(res == 3);
  53.     assert(width > 0);
  54.     assert(height > 0);
  55.     assert(val > 0 && val < 256);
  56.     N = width*height;
  57.     allocate();
  58.     res = fread(data, 3, N, fp);
  59.     assert(res == N);
  60.     fclose(fp);
  61. }
  62.  
  63. static void write_output(const char *path)
  64. {
  65.     FILE *fp;
  66.  
  67.     fp = fopen(path, "wb");
  68.     if (fp == NULL)
  69.     {
  70.         perror(path);
  71.         exit(EXIT_FAILURE);
  72.     }
  73.     fprintf(fp, "P6\n%d %d\n%d\n", width, height, 255);
  74.     fwrite(best, 3, N, fp);
  75.     fclose(fp);
  76. }
  77.  
  78. static void randomize_centroid(int k)
  79. {
  80. #if 0
  81.     int i, d;
  82.  
  83.     /* select random data point */
  84.     assert(RAND_MAX >= N);
  85.     i = rand()%N;
  86.     for (d = 0; d < D; ++d)
  87.     {
  88.         cntr[k][d] = data[i][d];
  89.     }
  90.  
  91. #else
  92.     int d;
  93.  
  94.     /* generate uniform random point in domain */
  95.     assert(RAND_MAX >= 255);
  96.     for (d = 0; d < D; ++d)
  97.     {
  98.         cntr[k][d] = rand()%256;
  99.     }
  100. #endif
  101. }
  102.  
  103. static void randomize_centroids()
  104. {
  105.     int k;
  106.  
  107.     for (k = 0; k < K; ++k)
  108.     {
  109.         randomize_centroid(k);
  110.     }
  111. }
  112.  
  113. /* Assigns points to nearest centroids.  Updates `idx' and `err'. */
  114. static long long cluster_points()
  115. {
  116.     int i, k, d, e, f, best_k, best_e;
  117.  
  118.     #pragma omp parallel for private(i, k, d, e, f, best_k, best_e)
  119.     for (i = 0; i < N; ++i)
  120.     {
  121.         best_e = 256*256*D;
  122.         best_k = -1;
  123.         for (k = 0; k < K; ++k)
  124.         {
  125.             e = 0;
  126.             for (d = 0; d < D; ++d)
  127.             {
  128.                 f = data[i][d] - cntr[k][d];
  129.                 e += f*f;
  130.             }
  131.             if (e < best_e)
  132.             {
  133.                 best_k = k;
  134.                 best_e = e;
  135.             }
  136.         }
  137.         assert(best_k >= 0);
  138.         idx[i] = best_k;
  139.         err[i] = best_e;
  140.     }
  141. }
  142.  
  143. /* Calculates per-cluster distortion.  Updates `size` and `sum'.
  144.    Returns total distortion. */
  145. static long long calculate_distortion()
  146. {
  147.     long long distortion;
  148.     int i, k, d;
  149.  
  150.     distortion = 0;
  151.     for (k = 0; k < K; ++k)
  152.     {
  153.         size[k] = 0;
  154.         for (d = 0; d < D; ++d)
  155.         {
  156.             sum[k][d] = 0;
  157.         }
  158.     }
  159.     for (i = 0; i < N; ++i)
  160.     {
  161.         k = idx[i];
  162.         for (d = 0; d < D; ++d)
  163.         {
  164.             sum[k][d] += data[i][d];
  165.         }
  166.         ++size[k];
  167.         distortion += err[i];
  168.     }
  169.     return distortion;
  170. }
  171.  
  172. /* Recomputes centroids.  Updates `cntr`. */
  173. static void recalculate_centroids()
  174. {
  175.     int k, d, i, empty;
  176.  
  177.     empty = 0;
  178.     for (k = 0; k < K; ++k)
  179.     {
  180.         if (size[k] > 0)
  181.         {
  182.             /* Calculate centroid as mean of coordinates of nearest points: */
  183.             for (d = 0; d < D; ++d)
  184.             {
  185.                 cntr[k][d] = (sum[k][d] + size[k]/2) / size[k];
  186.             }
  187.         }
  188.         else
  189.         {
  190.             /* Reassign at random. */
  191.             randomize_centroid(k);
  192.             ++empty;
  193.         }
  194.     }
  195.  
  196.     if (empty > 0)
  197.     {
  198.         printf("%d empty clusters reassigned.\n", empty);
  199.     }
  200. }
  201.  
  202. static void save_quantified_data()
  203. {
  204.     int i, d;
  205.  
  206.     for (i = 0; i < N; ++i)
  207.     {
  208.         for (d = 0; d < D; ++d)
  209.         {
  210.             best[i][d] = cntr[idx[i]][d];
  211.         }
  212.     }
  213. }
  214.  
  215. int main(int argc, char *argv[])
  216. {
  217.     int npass, niter, pass, iter;
  218.     long long distortion, last_distortion, min_distortion;
  219.  
  220.     K     = argc > 2 ? atoi(argv[2]) : -1;
  221.     npass = argc > 3 ? atoi(argv[3]) : -1;
  222.     niter = argc > 4 ? atoi(argv[4]) : -1;
  223.  
  224.     if (argc != 6 || K <= 0 || npass <= 0 || niter <= 0)
  225.     {
  226.         printf(
  227.             "Usage:\n"
  228.             "  %s <input> <K> <#pass> <#iter> <output>\n\n"
  229.             "Where:\n"
  230.             "  <input>   path to input file in PPM (P6) format\n"
  231.             "  <K>       number of clusters (colors) to create (e.g. 256)\n"
  232.             "  <#pass>   number of independent passes to run   (e.g. 10)\n"
  233.             "  <#iter>   maximum number of iterations per pass (e.g. 50)\n"
  234.             "  <output>  path to output file\n"
  235.             "\n", argv[0]);
  236.         return 0;
  237.     }
  238.  
  239.     read_input(argv[1]);
  240.     srand(time(NULL));
  241.     for (pass = 0; pass < npass; ++pass)
  242.     {
  243.         randomize_centroids();
  244.         for (iter = 0; iter < niter; ++iter)
  245.         {
  246.             if (iter > 0) recalculate_centroids();
  247.             cluster_points();
  248.             distortion = calculate_distortion();
  249.             printf( "pass=%d/%d iter=%d/%d distortion=%lf\n",
  250.                      pass + 1, npass, iter + 1, niter, (double)distortion/N );
  251.             if (iter > 0 && distortion >= last_distortion) break;
  252.             last_distortion = distortion;
  253.         }
  254.         if (pass == 0 || distortion < min_distortion)
  255.         {
  256.             min_distortion = distortion;
  257.             save_quantified_data();
  258.         }
  259.     }
  260.     write_output(argv[5]);
  261.     return 0;
  262. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement