Guest User

quantify-kmeans.c

a guest
Feb 19th, 2013
207
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
  6.  
  7. #define D 3                       /* number of dimensions (3 for RGB) */
  8.  
  9. static int width, height;         /* original image width and height */
  10. static int N;                     /* number of samples (width x height) */
  11. static int K;                     /* number of clusters to create */
  12. static unsigned char (*data)[D];  /* N coordinates of each point */
  13. static unsigned char (*best)[D];  /* N best quantified coordinates (so far) */
  14. static int *idx;                  /* N index of nearest centroid per point */
  15. static int *err;                  /* N squared distance to nearest centroid */
  16. static unsigned char (*cntr)[D];  /* K coordinates of each centroid */
  17. static int *size;                 /* K number of points nearest per centroid */
  18. static long long (*sum)[D];       /* K sum of coords of points near centroid */
  19.  
  20. static void allocate()
  21. {
  22.     data = malloc(N*sizeof(*data));
  23.     assert(data != NULL);
  24.     best = malloc(N*sizeof(*best));
  25.     assert(best != NULL);
  26.     idx = malloc(N*sizeof(*idx));
  27.     assert(idx != NULL);
  28.     err = malloc(N*sizeof(*err));
  29.     assert(err != NULL);
  30.     cntr = malloc(K*sizeof(*cntr));
  31.     assert(cntr != NULL);
  32.     size = malloc(K*sizeof(*size));
  33.     assert(size != NULL);
  34.     sum = malloc(K*sizeof(*sum));
  35.     assert(sum != NULL);
  36. }
  37.  
  38. /* Read input.  Initializes `data'. */
  39. static void read_input(const char *path)
  40. {
  41.     int i, d, res, val;
  42.     FILE *fp;
  43.  
  44.     fp = fopen(path, "rb");
  45.     if (fp == NULL)
  46.     {
  47.         perror(path);
  48.         exit(EXIT_FAILURE);
  49.     }
  50.     assert(fp != NULL);
  51.     res = fscanf(fp, "P6 %d %d %d%*1[\n]", &width, &height, &val);
  52.     assert(res == 3);
  53.     assert(width > 0);
  54.     assert(height > 0);
  55.     assert(val > 0 && val < 256);
  56.     N = width*height;
  57.     allocate();
  58.     res = fread(data, 3, N, fp);
  59.     assert(res == N);
  60.     fclose(fp);
  61. }
  62.  
  63. static void write_output(const char *path)
  64. {
  65.     FILE *fp;
  66.  
  67.     fp = fopen(path, "wb");
  68.     if (fp == NULL)
  69.     {
  70.         perror(path);
  71.         exit(EXIT_FAILURE);
  72.     }
  73.     fprintf(fp, "P6\n%d %d\n%d\n", width, height, 255);
  74.     fwrite(best, 3, N, fp);
  75.     fclose(fp);
  76. }
  77.  
  78. static void randomize_centroid(int k)
  79. {
  80. #if 0
  81.     int i, d;
  82.  
  83.     /* select random data point */
  84.     assert(RAND_MAX >= N);
  85.     i = rand()%N;
  86.     for (d = 0; d < D; ++d)
  87.     {
  88.         cntr[k][d] = data[i][d];
  89.     }
  90.  
  91. #else
  92.     int d;
  93.  
  94.     /* generate uniform random point in domain */
  95.     assert(RAND_MAX >= 255);
  96.     for (d = 0; d < D; ++d)
  97.     {
  98.         cntr[k][d] = rand()%256;
  99.     }
  100. #endif
  101. }
  102.  
  103. static void randomize_centroids()
  104. {
  105.     int k;
  106.  
  107.     for (k = 0; k < K; ++k)
  108.     {
  109.         randomize_centroid(k);
  110.     }
  111. }
  112.  
  113. /* Assigns points to nearest centroids.  Updates `idx' and `err'. */
  114. static long long cluster_points()
  115. {
  116.     int i, k, d, e, f, best_k, best_e;
  117.  
  118.     #pragma omp parallel for private(i, k, d, e, f, best_k, best_e)
  119.     for (i = 0; i < N; ++i)
  120.     {
  121.         best_e = 256*256*D;
  122.         best_k = -1;
  123.         for (k = 0; k < K; ++k)
  124.         {
  125.             e = 0;
  126.             for (d = 0; d < D; ++d)
  127.             {
  128.                 f = data[i][d] - cntr[k][d];
  129.                 e += f*f;
  130.             }
  131.             if (e < best_e)
  132.             {
  133.                 best_k = k;
  134.                 best_e = e;
  135.             }
  136.         }
  137.         assert(best_k >= 0);
  138.         idx[i] = best_k;
  139.         err[i] = best_e;
  140.     }
  141. }
  142.  
  143. /* Calculates per-cluster distortion.  Updates `size` and `sum'.
  144.    Returns total distortion. */
  145. static long long calculate_distortion()
  146. {
  147.     long long distortion;
  148.     int i, k, d;
  149.  
  150.     distortion = 0;
  151.     for (k = 0; k < K; ++k)
  152.     {
  153.         size[k] = 0;
  154.         for (d = 0; d < D; ++d)
  155.         {
  156.             sum[k][d] = 0;
  157.         }
  158.     }
  159.     for (i = 0; i < N; ++i)
  160.     {
  161.         k = idx[i];
  162.         for (d = 0; d < D; ++d)
  163.         {
  164.             sum[k][d] += data[i][d];
  165.         }
  166.         ++size[k];
  167.         distortion += err[i];
  168.     }
  169.     return distortion;
  170. }
  171.  
  172. /* Recomputes centroids.  Updates `cntr`. */
  173. static void recalculate_centroids()
  174. {
  175.     int k, d, i, empty;
  176.  
  177.     empty = 0;
  178.     for (k = 0; k < K; ++k)
  179.     {
  180.         if (size[k] > 0)
  181.         {
  182.             /* Calculate centroid as mean of coordinates of nearest points: */
  183.             for (d = 0; d < D; ++d)
  184.             {
  185.                 cntr[k][d] = (sum[k][d] + size[k]/2) / size[k];
  186.             }
  187.         }
  188.         else
  189.         {
  190.             /* Reassign at random. */
  191.             randomize_centroid(k);
  192.             ++empty;
  193.         }
  194.     }
  195.  
  196.     if (empty > 0)
  197.     {
  198.         printf("%d empty clusters reassigned.\n", empty);
  199.     }
  200. }
  201.  
  202. static void save_quantified_data()
  203. {
  204.     int i, d;
  205.  
  206.     for (i = 0; i < N; ++i)
  207.     {
  208.         for (d = 0; d < D; ++d)
  209.         {
  210.             best[i][d] = cntr[idx[i]][d];
  211.         }
  212.     }
  213. }
  214.  
  215. int main(int argc, char *argv[])
  216. {
  217.     int npass, niter, pass, iter;
  218.     long long distortion, last_distortion, min_distortion;
  219.  
  220.     K     = argc > 2 ? atoi(argv[2]) : -1;
  221.     npass = argc > 3 ? atoi(argv[3]) : -1;
  222.     niter = argc > 4 ? atoi(argv[4]) : -1;
  223.  
  224.     if (argc != 6 || K <= 0 || npass <= 0 || niter <= 0)
  225.     {
  226.         printf(
  227.             "Usage:\n"
  228.             "  %s <input> <K> <#pass> <#iter> <output>\n\n"
  229.             "Where:\n"
  230.             "  <input>   path to input file in PPM (P6) format\n"
  231.             "  <K>       number of clusters (colors) to create (e.g. 256)\n"
  232.             "  <#pass>   number of independent passes to run   (e.g. 10)\n"
  233.             "  <#iter>   maximum number of iterations per pass (e.g. 50)\n"
  234.             "  <output>  path to output file\n"
  235.             "\n", argv[0]);
  236.         return 0;
  237.     }
  238.  
  239.     read_input(argv[1]);
  240.     srand(time(NULL));
  241.     for (pass = 0; pass < npass; ++pass)
  242.     {
  243.         randomize_centroids();
  244.         for (iter = 0; iter < niter; ++iter)
  245.         {
  246.             if (iter > 0) recalculate_centroids();
  247.             cluster_points();
  248.             distortion = calculate_distortion();
  249.             printf( "pass=%d/%d iter=%d/%d distortion=%lf\n",
  250.                      pass + 1, npass, iter + 1, niter, (double)distortion/N );
  251.             if (iter > 0 && distortion >= last_distortion) break;
  252.             last_distortion = distortion;
  253.         }
  254.         if (pass == 0 || distortion < min_distortion)
  255.         {
  256.             min_distortion = distortion;
  257.             save_quantified_data();
  258.         }
  259.     }
  260.     write_output(argv[5]);
  261.     return 0;
  262. }
RAW Paste Data

Adblocker detected! Please consider disabling it...

We've detected AdBlock Plus or some other adblocking software preventing Pastebin.com from fully loading.

We don't have any obnoxious sound, or popup ads, we actively block these annoying types of ads!

Please add Pastebin.com to your ad blocker whitelist or disable your adblocking software.

×