Advertisement
Guest User

Untitled

a guest
Apr 9th, 2020
230
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 5.97 KB | None | 0 0
  1. #define _CRT_SECURE_NO_WARNINGS
  2. #define __USE_MINGW_ANSI_STDIO 0
  3. #pragma comment(linker, "/STACK:256000000")
  4. #pragma GCC optimize("Ofast")
  5. #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
  6.  
  7. #include <stdio.h>
  8. #include <math.h>
  9.  
  10. #define MIN(a,b) (((a)<(b))?(a):(b))
  11. #define MAX(a,b) (((a)>(b))?(a):(b))
  12.  
  13. #define LEARNING_RATE 0.1
  14. #define MAX_TRAIN_SET_SIZE 10000
  15. #define MAX_CATEGORIES_SIZE 1000 + 1
  16.  
  17. #define STEPS_COUNT 5000
  18.  
/******************** Common data *******************/
int n; // number of training samples (rows read by loadData)
int m; // number of model coefficients: input feature count + 1 for the bias

double x_sets[MAX_TRAIN_SET_SIZE][MAX_CATEGORIES_SIZE]; // feature matrix, one row per sample
double y_sets[MAX_TRAIN_SET_SIZE];                      // target value per sample
double weights[MAX_CATEGORIES_SIZE];                    // learned coefficients; weights[m-1] acts as the intercept

void loadData();

/************* Packet settings **********************/
// A "packet" is a mini-batch: the half-open sample range
// [packet_start_index, packet_end_index) used for one gradient step.
#define PACKET_SIZE_FRACTION 0.1 // 10% — currently unused; only referenced from commented-out code in initializePacketSettings

int packet_start_index; // first sample index of the current packet (inclusive)
int packet_end_index;   // one past the last sample of the current packet
int packet_size;        // nominal packet length; the final packet may be shorter

  36. void initializePacketSettings() {
  37.     //packet_size = 1;
  38.     //if (n < 100) {
  39.     //    packet_size = n;
  40.     //}
  41.     //else {
  42.     //    packet_size = n * PACKET_SIZE_FRACTION;
  43.     //}
  44.     if (n < 20) {
  45.         packet_size = n;
  46.     }
  47.     else {
  48.         packet_size = 20;
  49.     }
  50.     packet_start_index = 0;
  51.     packet_end_index = packet_size;
  52. }
  53.  
/***** Normalization and denormalization **********/
double mean_for_category[MAX_CATEGORIES_SIZE]; // per-feature mean over all n samples
double mean_for_y;                             // mean of the target values

double standard_deviation_for_category[MAX_CATEGORIES_SIZE]; // per-feature population std-dev
double standard_deviation_for_y;                             // population std-dev of the targets

void normalize();                    // z-score features and targets in place
void denormalize();                  // rescale learned weights back to original units
void calculateMeans();               // fill mean_for_category / mean_for_y
void calculateStandardDeviations();  // fill the std-dev arrays; needs means first

/******** Gradient descent **************/
double grad_for_sets[MAX_TRAIN_SET_SIZE][MAX_CATEGORIES_SIZE]; // per-sample gradient contributions
double grad[MAX_CATEGORIES_SIZE];                              // summed gradient for the current packet

void train();                            // run STEPS_COUNT mini-batch descent iterations
void updateGrad();                       // recompute grad[] over the current packet
void updateGradForSet(int set_number);   // gradient contribution of one sample
double predict(double* x);               // dot product of weights with a sample
void updateWeights();                    // one descent step: w -= LEARNING_RATE * grad
void moveToTheNextPacket();              // advance the packet window, wrapping at n

  77. int main() {
  78.     loadData();
  79.     initializePacketSettings();
  80.     normalize();
  81.     train();
  82.     denormalize();
  83.  
  84.     for (int i = 0; i < m; i++) {
  85.         printf("%lf\n", weights[i]);
  86.     }
  87.  
  88.     return 0;
  89. }
  90.  
  91. // Read data
  92. void loadData() {
  93.     scanf("%d%d", &n, &m);
  94.     m++;
  95.  
  96.     for (int i = 0; i < n; i++) {
  97.         for (int j = 0; j <= m; j++) {
  98.             if (j == m - 1) {
  99.                 x_sets[i][m] = 1;
  100.                 continue;
  101.             }
  102.  
  103.             if (j == m) {
  104.                 scanf("%lf", &y_sets[i]);
  105.                 continue;
  106.             }
  107.  
  108.             scanf("%lf", &x_sets[i][j]);
  109.         }
  110.     }
  111. }
  112.  
  113. // Normalization and denoramlization
  114. void normalize() {
  115.     calculateMeans();
  116.     calculateStandardDeviations();
  117.     for (int i = 0; i < n; i++) {
  118.         for (int j = 0; j < m; j++) {
  119.             double x = x_sets[i][j];
  120.             double mean = mean_for_category[j];
  121.             double standard_deviation = standard_deviation_for_category[j];
  122.             /* if standard deviation is zero, it means that all 'x' have the same
  123.             values, therefore we do not normalize them. Do not forget this point when
  124.             you are using denormalization algorithm */
  125.             if (standard_deviation != 0) {
  126.                 x_sets[i][j] = (x - mean) / standard_deviation;
  127.             }
  128.         }
  129.         y_sets[i] = (y_sets[i] - mean_for_y) / standard_deviation_for_y;
  130.     }
  131. }
  132.  
  133. void calculateMeans() {
  134.     for (int i = 0; i < n; i++) {
  135.         for (int j = 0; j < m; j++) {
  136.             mean_for_category[j] += x_sets[i][j];
  137.         }
  138.         mean_for_y += y_sets[i];
  139.     }
  140.     mean_for_y /= n;
  141.     for (int j = 0; j < m; j++) {
  142.         mean_for_category[j] /= n;
  143.     }
  144. }
  145.  
  146. void calculateStandardDeviations() {
  147.     for (int i = 0; i < n; i++) {
  148.         for (int j = 0; j < m; j++) {
  149.             double x = x_sets[i][j];
  150.             double mean = mean_for_category[j];
  151.             standard_deviation_for_category[j] += (x - mean) * (x - mean);
  152.         }
  153.         standard_deviation_for_y += (y_sets[i] - mean_for_y) * (y_sets[i] - mean_for_y);
  154.     }
  155.     standard_deviation_for_y /= n;
  156.     standard_deviation_for_y = sqrt(standard_deviation_for_y);
  157.     for (int j = 0; j < m; j++) {
  158.         standard_deviation_for_category[j] /= n;
  159.         standard_deviation_for_category[j] = sqrt(standard_deviation_for_category[j]);
  160.     }
  161. }
  162.  
// Map the weights learned on z-scored data back to the original scale:
// w_orig[j] = w_norm[j] * sd_y / sd_x[j], with the intercept absorbing the
// means: bias = mean_y - sum_j(w_orig[j] * mean_x[j]).
void denormalize() {
    // weights[m - 1] is the bias (intercept) coefficient.
    for (int i = 0; i < m - 1; i++) {
        if (standard_deviation_for_category[i] != 0) {
            double res = weights[i] * standard_deviation_for_y / standard_deviation_for_category[i];
            weights[i] = res;
            weights[m - 1] -= res * mean_for_category[i];
        }
        // NOTE(review): a zero-deviation (constant) feature column is skipped
        // here, leaving its weight in target-normalized scale — presumably
        // harmless because such a column never varies, but worth confirming.
    }
    weights[m - 1] += mean_for_y;
}
  173.  
  174. // Gradient descent
  175. void train() {
  176.     for (int i = 0; i < STEPS_COUNT; i++) {
  177.         updateGrad();
  178.         updateWeights();
  179.         moveToTheNextPacket();
  180.     }
  181. }
  182.  
  183. void moveToTheNextPacket() {
  184.     if (packet_end_index == n) {
  185.         packet_start_index = 0;
  186.         packet_end_index = packet_size;
  187.     }
  188.     else {
  189.         packet_start_index = packet_end_index;
  190.         packet_end_index = MIN(packet_end_index + packet_size, n);
  191.     }
  192. }
  193.  
  194. void updateGrad() {
  195.     for (int i = packet_start_index; i < packet_end_index; i++) {
  196.         updateGradForSet(i);
  197.         for (int j = 0; j < m; j++) {
  198.             if (i == 0) {
  199.                 grad[j] = grad_for_sets[0][j];
  200.             }
  201.             else {
  202.                 grad[j] += grad_for_sets[i][j];
  203.             }
  204.         }
  205.     }
  206. }
  207.  
  208. void updateGradForSet(int set_number) {
  209.     double* x = x_sets[set_number];
  210.     double y = y_sets[set_number];
  211.     double y_pred = predict(x);
  212.     double residual = (y_pred - y);
  213.     for (int i = 0; i < m; i++) {
  214.         double grad = residual * x[i];
  215.         grad_for_sets[set_number][i] = grad;
  216.     }
  217. }
  218.  
  219. double predict(double* x) {
  220.     double prediction = 0;
  221.     for (int i = 0; i < m; i++) {
  222.         prediction += weights[i] * x[i];
  223.     }
  224.     return prediction;
  225. }
  226.  
  227. void updateWeights() {
  228.     for (int i = 0; i < m; i++) {
  229.         weights[i] -= LEARNING_RATE * grad[i];
  230.     }
  231. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement