Advertisement
Guest User

Untitled

a guest
Jun 25th, 2017
53
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 10.35 KB | None | 0 0
  1. // simple_backprop.cpp : Defines the entry point for the console application.
  2. //
  3. #include <cmath>
  4. #include <iostream>
  5.  
  6. #include <fstream>
  7.  
  8. #include <iomanip>
  9. #include <random>
  10. #include <algorithm>
  11. #include <cstdint>
  12.  
  13. #include <arpa/inet.h> // for ntohl
  14.  
  15. using namespace std;
  16.  
  17. typedef float flt;
  18.  
  19. //#pragma omp
  20.  
  21. /*
  22. inline void activation_function(flt a, flt &result, flt &der) {
  23.     result=a>-20? 1.0 / (1.0 + exp(-a)) : 0.0;
  24.     der= result*(flt(1.0) - result);
  25. }*/
  26.  
  27.  
  28. inline void activation_function(flt a, flt &result, flt &der) {// der = dresult/da
  29.     const float param=0.1;
  30.     result = a>0.0 ? a : a*param;
  31.     der = a>0.0 ? 1.0 : param;
  32. }
  33.  
  34. flt loss_function(flt a, flt b) {
  35.     return flt(0.5)*(a-b)*(a-b);
  36. }
  37.  
  38. flt loss_function_der(flt a, flt b) {// d loss_function(a,b)/db
  39.     return b - a;
  40. }
  41.  
  42. const flt step_grow=1.2f;
  43. const flt step_shrink=0.5;
  44.  
  45. const flt step_min=1E-10;
  46. const flt step_max= 10.0;
  47.  
  48. template<class T>
  49. T clamp(T a, T low, T high){
  50.     return std::min(std::max(a,low),high);
  51. }
  52.  
  53. std::random_device rdev;
  54. mt19937 rng(rdev());
  55. std::uniform_real_distribution<> r(0, 1);
  56.  
  57. template<int n_inputs_, int n_outputs_>
  58. struct layer {
  59.     static const int n_inputs=n_inputs_;
  60.     static const int n_outputs=n_outputs_;
  61.     flt input_values[n_inputs]{};
  62.  
  63.     flt weights[n_outputs][n_inputs]{};
  64.     flt old_weights[n_outputs][n_inputs]{};
  65.  
  66.     flt derr_per_weight_sums[n_outputs][n_inputs]{};
  67.     flt old_derr_per_weight_sums[n_outputs][n_inputs]{};
  68.  
  69.     flt weight_step_size[n_outputs][n_inputs]{};
  70.  
  71.     flt input_derivatives[n_inputs]{};
  72.     flt output_values[n_outputs]{};
  73.     flt dout_per_dsum[n_outputs]{};
  74.     flt derr_per_dout[n_outputs]{};
  75.  
  76.     layer() {// initialize weights to sum to 1 for now
  77.  
  78.  
  79.         for (int o = 0; o < n_outputs; ++o)
  80.             for (int i = 0; i < n_inputs; ++i){
  81.                 weights[o][i] = old_weights[o][i] = r(rng)*0.1-0.05;//0.1*(r(rng)*0.2+0.9)/n_inputs;
  82.                 weight_step_size[o][i]=0.01;
  83.             }
  84.     }
  85.  
  86.     void forward(const flt inputs[n_inputs]) {
  87.         //#pragma omp parallel_for
  88.         for (int i = 0; i < n_inputs; ++i) {
  89.             input_values[i] = inputs[i];
  90.         }
  91.         //#pragma omp parallel_for
  92.         for (int o = 0; o < n_outputs; ++o) {
  93.             flt sum = 0;
  94.             for (int i = 0; i < n_inputs; ++i) {
  95.                 sum += inputs[i]*weights[o][i];
  96.             }
  97.             activation_function(sum, output_values[o], dout_per_dsum[o]);
  98.         }
  99.     }
  100.     // must do forward step and initialize derr_per_dout before calling this
  101.     void backward (flt learn_rate, flt derr_per_din[n_inputs]) {
  102.         // zero the input (we'll sum inplace)
  103.         //#pragma omp parallel_for
  104.         for (int i = 0; i < n_inputs; ++i) {
  105.             derr_per_din[i] = flt(0);
  106.         }
  107.  
  108.         //float derr_per_dsum[n_inputs];
  109.         // I should turn this loop inside out for parallelization
  110.         for (int o = 0; o < n_outputs; ++o) {
  111.             flt derr_per_dsum = derr_per_dout[o] * dout_per_dsum[o]; // d f(g(x)) / dx = f'(g(x)) * g'(x)
  112.  
  113.             //#pragma omp parallel for
  114.             for (int i = 0; i < n_inputs; ++i) {
  115.                 derr_per_din[i] += weights[o][i] * derr_per_dsum; // dsum/din = weights[o][i]
  116.  
  117.                 flt derr_per_weight = derr_per_dsum * input_values[i]; // dsum/dweight = input_values[i]
  118.  
  119.                 weights[o][i] -= learn_rate * derr_per_weight ;
  120.                 //derr_per_weight_sums[o][i] += derr_per_weight ;
  121.             }
  122.         }
  123.     }
  124.  
  125.     void backward_from_target(flt learn_rate, flt derr_per_din[n_inputs], const flt desired_outputs[n_inputs]){
  126.         //#pragma omp parallel_for
  127.         for(int o=0; o<n_outputs; ++o){
  128.             derr_per_dout[o] = loss_function_der(desired_outputs[o], output_values[o]);
  129.         }
  130.         backward(learn_rate, derr_per_din);
  131.     }
  132.  
  133.     // for Rprop
  134.     /*
  135.     void update_weights(){
  136.         return;
  137.         for (int o = 0; o < n_outputs; ++o)
  138.             for (int i = 0; i < n_inputs; ++i){
  139.                 if((old_derr_per_weight_sums[o][i]>0) != (derr_per_weight_sums[o][i]>0)){// overshot, weight rollback
  140.                     weight_step_size[o][i]*=step_shrink;
  141.                     weights[o][i]=old_weights[o][i];
  142.                 }else{
  143.                     old_weights[o][i]=weights[o][i];
  144.                     weight_step_size[o][i]*=step_grow;
  145.                     weight_step_size[o][i]=clamp(weight_step_size[o][i], step_min, step_max);
  146.  
  147.                     weights[o][i]-=derr_per_weight_sums[o][i]>0 ? weight_step_size[o][i] : -weight_step_size[o][i];
  148.  
  149.                     //weights[o][i] -= derr_per_weight_sums[o][i]*weight_step_size[o][i];
  150.  
  151.                     weights[o][i]=clamp(weights[o][i], flt(-4), flt(4));
  152.                 }
  153.                 old_derr_per_weight_sums[o][i]=derr_per_weight_sums[o][i];
  154.                 derr_per_weight_sums[o][i]=0;
  155.             }
  156.     }*/
  157.  
  158.     void print() {
  159.         for (int o = 0; o < n_outputs; ++o) {
  160.             for (int i = 0; i < n_inputs; ++i) {
  161.                 std::cout << weights[o][i] << "\t";
  162.             }
  163.             std::cout << std::endl;
  164.         }
  165.     }
  166. };
  167.  
  168. layer<28*28, 600> input_layer;
  169. layer<600,300> hidden_layer;
  170. layer<300,10> output_layer;
  171.  
  172. void forward(const flt inputs[input_layer.n_inputs]) {
  173.     input_layer.forward(inputs);
  174.     hidden_layer.forward(input_layer.output_values);
  175.     output_layer.forward(hidden_layer.output_values);
  176. }
  177. float loss(const flt desired_outputs[output_layer.n_outputs]) {
  178.     flt error_sum{};
  179.     for (int i = 0; i < output_layer.n_outputs; ++i) {
  180.         error_sum += loss_function(desired_outputs[i], output_layer.output_values[i]);
  181.     }
  182.     return error_sum;
  183. }
  184.  
  185. void backward(flt learn_rate, const flt desired_outputs[output_layer.n_outputs]) {
  186.     output_layer.backward_from_target(learn_rate, hidden_layer.derr_per_dout, desired_outputs);
  187.     hidden_layer.backward(learn_rate, input_layer.derr_per_dout);
  188.     flt tmp[input_layer.n_inputs];
  189.     input_layer.backward(learn_rate, tmp); // error deltas of the input are unused (todo: use them for visualization?)
  190. }
  191.  
  192. /*
  193. void update_weights(){
  194.     for(auto &a: layers){
  195.         a.update_weights();
  196.     }
  197. }*/
  198.  
  199. struct hand_written_digits_dataset{
  200.     std::vector<flt> data;
  201.     std::vector<uint8_t> labels;
  202.     uint32_t data_width{}, data_height{}, count{};
  203.     flt *get_image(size_t i){
  204.         return data.data()+i*data_width*data_height;
  205.     }
  206. };
  207.  
  208. hand_written_digits_dataset train_set;
  209. hand_written_digits_dataset test_set;
  210.  
  211. void load_data(hand_written_digits_dataset &result, const char *images_filename, const char *labels_filename){
  212.     struct {
  213.         uint32_t magic;
  214.         uint32_t count;
  215.         uint32_t width;
  216.         uint32_t height;
  217.     } images_header{};
  218.     struct {
  219.         uint32_t magic;
  220.         uint32_t count;
  221.     } labels_header{};
  222.     std::cout<<"Reading images file "<<images_filename<<" and labels "<<labels_filename<<std::endl;
  223.     std::ifstream f(images_filename, ios::binary|ios::in);
  224.     f.read((char *)&images_header, sizeof(images_header));
  225.     // convert
  226.     images_header.magic=ntohl(images_header.magic);
  227.     images_header.count=ntohl(images_header.count);
  228.     result.data_width=images_header.width=ntohl(images_header.width);
  229.     result.data_height=images_header.height=ntohl(images_header.height);
  230.  
  231.     std::ifstream fl(labels_filename, ios::binary|ios::in);
  232.     fl.read((char *)&labels_header, sizeof(labels_header));
  233.     labels_header.magic=ntohl(labels_header.magic);
  234.     labels_header.count=ntohl(labels_header.count);
  235.  
  236.     if(labels_header.count!=images_header.count){
  237.         std::cout<<"count mismatch "<<images_header.count<<" "<<labels_header.count<<std::endl;
  238.         return;
  239.     }
  240.     std::cout<<"loading "<<images_header.count<<" images"<<std::endl;
  241.  
  242.     std::vector<uint8_t> images_data(images_header.count*images_header.width*images_header.height);
  243.     f.read((char *)images_data.data(), images_data.size());
  244.     if(!f.good()){
  245.         std::cout<<"failed reading images"<<std::endl;
  246.         return;
  247.     }
  248.     result.labels.resize(labels_header.count);
  249.     fl.read((char *)result.labels.data(), result.labels.size());
  250.     if(!fl.good()){
  251.         std::cout<<"failed reading labels"<<std::endl;
  252.         return;
  253.     }
  254.     result.data.resize(images_data.size());
  255.     for(size_t i=0; i<images_data.size(); ++i){
  256.         result.data[i]=images_data[i]*(1.0f/255.0f);
  257.     }
  258.     result.count=images_header.count;
  259.  
  260.     std::cout<<"read successful"<<std::endl;
  261. }
  262.  
  263. int recognize(flt *inputs){
  264.     forward(inputs);
  265.     flt max_v=-1E10;
  266.     int max_i=0;
  267.     for(int i=0;i<10;++i){
  268.     flt v=output_layer.output_values[i];
  269.         if(v>max_v){
  270.             max_v=v;
  271.             max_i=i;
  272.         }
  273.     }
  274.     return max_i;
  275. }
  276.  
  277. void test(){
  278.     flt loss_sum = 0;
  279.     int errors = 0;
  280.     for (int example = 0;  example < test_set.count; ++example) {
  281.         errors+=recognize(test_set.get_image(example))!=test_set.labels[example];
  282.     }
  283.     std::cout<<"Error rate in test dataset: "<<(((float)errors)/test_set.count)<<std::endl;
  284. }
  285.  
  286. void train() {
  287.     for (int i = 0; i <= 1000000; ++i){
  288.         flt loss_sum = 0;
  289.         bool print = true;//!(i % 100);
  290.         int errors = 0;
  291.         for (int example = 0;  example < train_set.count; ++example) {
  292.             flt *inputs=train_set.get_image(example);
  293.  
  294.             flt desired_outputs[10]{};
  295.             int label=train_set.labels[example];
  296.             if(label<0 || label>=10){
  297.                 std::cout<<"invalid label"<<std::endl;
  298.             }
  299.             desired_outputs[label]=1;
  300.  
  301.             forward(inputs);
  302.             loss_sum += loss(desired_outputs);
  303.  
  304.             flt max_v=-1E10;
  305.             int max_i=0;
  306.             for(int i=0;i<10;++i){
  307.                 flt v=output_layer.output_values[i];
  308.                 if(v>max_v){
  309.                     max_v=v;
  310.                     max_i=i;
  311.                 }
  312.             }
  313.             errors+=max_i!=label;
  314.  
  315.             backward(0.01, desired_outputs);
  316.  
  317.             if(!(example % 1000))std::cout<<example<<std::endl;
  318.         }// example loop
  319.         if (print) {
  320.             std::cout << "epoch " << i <<" loss sum over all examples:" << std::fixed << std::setprecision(6) << loss_sum <<" errors:" << errors << std::endl;
  321.         }
  322.         if (!errors) {
  323.             std::cout << "Set fully recognized?!" << std::endl;
  324.             break;
  325.         }
  326.         if(!(i%2)){
  327.             test();
  328.         }
  329.     }
  330. }
  331.  
  332. const char *symbols=" .,-+coC0#";
  333.  
  334. void print_image(flt *image, int w, int h){
  335.  
  336.     for(int i=0;i<h;++i){
  337.         for(int j=0;j<w;++j){
  338.             float a=*(image++);
  339.             int av=clamp(int(a*10.0), 0, 9);
  340.             std::cout<<symbols[av];
  341.         }
  342.         std::cout<<std::endl;
  343.     }
  344. }
  345.  
  346.  
  347. int main()
  348. {
  349.     load_data(train_set, "../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte");
  350.     load_data(test_set, "../data/t10k-images.idx3-ubyte", "../data/t10k-labels.idx1-ubyte");
  351.     print_image(test_set.get_image(0), 28, 28);
  352.  
  353.     train();
  354.     test();
  355.     char c;
  356.     std::cin >> c;
  357.     return 0;
  358. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement