// Includes follow the layout of the pre-2.0 mlpack ann module; exact header
// paths may differ between mlpack versions.
#include <mlpack/core.hpp>

#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
#include <mlpack/methods/ann/activation_functions/rectifier_function.hpp>

#include <mlpack/methods/ann/init_rules/zero_init.hpp>

#include <mlpack/methods/ann/layer/one_hot_layer.hpp>
#include <mlpack/methods/ann/layer/conv_layer.hpp>
#include <mlpack/methods/ann/layer/pooling_layer.hpp>
#include <mlpack/methods/ann/layer/softmax_layer.hpp>
#include <mlpack/methods/ann/layer/bias_layer.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/layer/dropout_layer.hpp>
#include <mlpack/methods/ann/layer/linear_layer.hpp>

#include <mlpack/methods/ann/trainer/trainer.hpp>
#include <mlpack/methods/ann/cnn.hpp>
#include <mlpack/methods/ann/optimizer/rmsprop.hpp>
#include <mlpack/methods/ann/optimizer/ada_delta.hpp>

#include <boost/test/unit_test.hpp>

#include <iostream>

using namespace mlpack;
using namespace mlpack::ann;

BOOST_AUTO_TEST_SUITE(ConvolutionalNetworkTest);

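/**
 * Evaluate the trained network: run Predict() on slices [begin, end) of the
 * input cube and print the fraction of predictions whose one-hot output
 * activates activate_node.
 */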
template<typename Net>
void predict_digits(Net& net, arma::cube const& input,
                    size_t begin, size_t end,
                    size_t activate_node)
{
    arma::cube predict_data(input.n_rows, input.n_cols, 1);
    arma::mat label;
    double sum = 0;
    for (size_t i = begin; i != end; ++i)
    {
        // The network expects a cube, so predict one slice (image) at a time.
        predict_data.slice(0) = input.slice(i);
        net.Predict(predict_data, label);
        if (label(activate_node) == 1)
        {
            ++sum;
        }
    }

    std::cout << "prediction accuracy for node " << activate_node << ": "
              << (sum / (end - begin)) << "\n";
}

/**
 * Train and evaluate a vanilla network with the specified structure.
 */
template<typename PerformanceFunction>
void BuildVanillaNetwork()
{
    arma::mat X;
    X.load("mnist_first250_training_4s_and_9s.arm");

    // Normalize each point since these are images.
    arma::uword nPoints = X.n_cols;
    for (arma::uword i = 0; i < nPoints; i++)
    {
        X.col(i) /= norm(X.col(i), 2);
    }

    // Build the target matrix: one-hot targets, with output node 5 active
    // for the first half of the points and node 8 for the second half.
    arma::mat Y = arma::zeros<arma::mat>(10, nPoints);
    for (size_t i = 0; i < nPoints; i++)
    {
        if (i < nPoints / 2)
        {
            Y.col(i)(5) = 1;
        }
        else
        {
            Y.col(i)(8) = 1;
        }
    }

    // Reshape each 784-element column into a 28x28 image slice.
    arma::cube input = arma::cube(28, 28, nPoints);
    for (size_t i = 0; i < nPoints; i++)
        input.slice(i) = arma::mat(X.colptr(i), 28, 28);

    /*
     * Construct a convolutional neural network with a 28x28x1 input layer,
     * 24x24x8 convolution layer, 12x12x8 pooling layer, 8x8x12 convolution
     * layer and a 4x4x12 pooling layer which is fully connected with the
     * output layer. The network structure looks like:
     *
     * Input    Convolution  Pooling      Convolution  Pooling      Output
     * Layer    Layer        Layer        Layer        Layer        Layer
     *
     *          +---+        +---+        +---+        +---+
     *          | +---+      | +---+      | +---+      | +---+
     * +---+    | | +---+    | | +---+    | | +---+    | | +---+    +---+
     * |   |    | | |   |    | | |   |    | | |   |    | | |   |    |   |
     * |   +--> +-+ |   +--> +-+ |   +--> +-+ |   +--> +-+ |   +--> |   |
     * |   |      +-+   |      +-+   |      +-+   |      +-+   |    |   |
     * +---+        +---+        +---+        +---+        +---+    +---+
     */
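
    // Shape bookkeeping (valid 5x5 convolutions, 2x2 pooling):
    // 28 - 5 + 1 = 24, 24 / 2 = 12, 12 - 5 + 1 = 8, 8 / 2 = 4; the flattened
    // pooling output is 4 * 4 * 12 = 192, which matches the linear layer below.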

    // First convolution block: 1 input map -> 8 feature maps, 5x5 kernels,
    // RMSprop parameter updates.
    ConvLayer<RMSPROP> convLayer0(1, 8, 5, 5);
    BiasLayer2D<RMSPROP, ZeroInitialization> biasLayer0(8);
    BaseLayer2D<PerformanceFunction> baseLayer0;
    PoolingLayer<> poolingLayer0(2);

    // Second convolution block: 8 input maps -> 12 feature maps, 5x5 kernels.
    ConvLayer<RMSPROP> convLayer1(8, 12, 5, 5);
    BiasLayer2D<RMSPROP, ZeroInitialization> biasLayer1(12);
    BaseLayer2D<PerformanceFunction> baseLayer1;
    PoolingLayer<> poolingLayer1(2);

    // Fully connected output: 192 pooled activations -> 10 output nodes.
    LinearMappingLayer<RMSPROP> linearLayer0(192, 10);
    BiasLayer<RMSPROP> biasLayer2(10);
    SoftmaxLayer<> softmaxLayer0;

    OneHotLayer outputLayer;

    // std::tie keeps lvalue references to the layers; the CNN stores the tuple.
    auto modules = std::tie(convLayer0, biasLayer0, baseLayer0, poolingLayer0,
                            convLayer1, biasLayer1, baseLayer1, poolingLayer1,
                            linearLayer0, biasLayer2, softmaxLayer0);

    CNN<decltype(modules), decltype(outputLayer)>
            net(modules, outputLayer);

    // Train for at most 50 epochs with batch size 1 and tolerance 0.7,
    // using the training data as the validation set as well.
    Trainer<decltype(net)> trainer(net, 50, 1, 0.7);
    trainer.Train(input, Y, input, Y);

    predict_digits(net, input, 0, nPoints / 2, 5);
    predict_digits(net, input, nPoints / 2, nPoints, 8);

    BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7);
}

/**
 * Train and evaluate the vanilla network on the MNIST 4s-and-9s subset.
 */
BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
{
    BuildVanillaNetwork<LogisticFunction>();
}

/**
 * Train and evaluate a vanilla network with dropout, using the specified
 * structure.
 */
template<typename PerformanceFunction>
void BuildVanillaDropoutNetwork()
{
    arma::mat X;
    X.load("mnist_first250_training_4s_and_9s.arm");

    // Normalize each point since these are images.
    arma::uword nPoints = X.n_cols;
    for (arma::uword i = 0; i < nPoints; i++)
    {
        X.col(i) /= norm(X.col(i), 2);
    }

    // Build the target matrix: one-hot targets, with output node 5 active
    // for the first half of the points and node 8 for the second half.
    arma::mat Y = arma::zeros<arma::mat>(10, nPoints);
    for (size_t i = 0; i < nPoints; i++)
    {
        if (i < nPoints / 2)
        {
            Y.col(i)(5) = 1;
        }
        else
        {
            Y.col(i)(8) = 1;
        }
    }

    // Reshape each 784-element column into a 28x28 image slice.
    arma::cube input = arma::cube(28, 28, nPoints);
    for (size_t i = 0; i < nPoints; i++)
        input.slice(i) = arma::mat(X.colptr(i), 28, 28);

    /*
     * Construct a convolutional neural network with a 28x28x1 input layer,
     * 24x24x4 convolution layer, 24x24x4 dropout layer, 12x12x4 pooling layer,
     * 8x8x8 convolution layer, 8x8x8 dropout layer and a 4x4x8 pooling layer
     * which is fully connected with the output layer. The network structure
     * looks like:
     *
     * Input    Convolution  Dropout      Pooling     Convolution,     Output
     * Layer    Layer        Layer        Layer       Dropout,         Layer
     *                                                Pooling Layer
     *
     *          +---+        +---+        +---+
     *          | +---+      | +---+      | +---+
     * +---+    | | +---+    | | +---+    | | +---+                    +---+
     * |   |    | | |   |    | | |   |    | | |   |                    |   |
     * |   +--> +-+ |   +--> +-+ |   +--> +-+ |   +--> ............--> |   |
     * |   |      +-+   |      +-+   |      +-+   |                    |   |
     * +---+        +---+        +---+        +---+                    +---+
     */
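
    // Shape bookkeeping (valid 5x5 convolutions, 2x2 pooling):
    // 28 - 5 + 1 = 24, 24 / 2 = 12, 12 - 5 + 1 = 8, 8 / 2 = 4; the flattened
    // pooling output is 4 * 4 * 8 = 128, which matches the linear layer below.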

    // First convolution block (AdaDelta updates): 1 input map -> 4 feature
    // maps, 5x5 kernels, with dropout applied before the activation.
    ConvLayer<AdaDelta> convLayer0(1, 4, 5, 5);
    BiasLayer2D<AdaDelta, ZeroInitialization> biasLayer0(4);
    DropoutLayer2D<> dropoutLayer0;
    BaseLayer2D<PerformanceFunction> baseLayer0;
    PoolingLayer<> poolingLayer0(2);

    // Second convolution block: 4 input maps -> 8 feature maps, 5x5 kernels.
    ConvLayer<AdaDelta> convLayer1(4, 8, 5, 5);
    BiasLayer2D<AdaDelta, ZeroInitialization> biasLayer1(8);
    DropoutLayer2D<> dropoutLayer1;
    BaseLayer2D<PerformanceFunction> baseLayer1;
    PoolingLayer<> poolingLayer1(2);

    // Fully connected output: 128 pooled activations -> 10 output nodes.
    LinearMappingLayer<AdaDelta> linearLayer0(128, 10);
    BiasLayer<AdaDelta> biasLayer2(10);
    SoftmaxLayer<> softmaxLayer0;

    OneHotLayer outputLayer;

    // std::tie keeps lvalue references to the layers; the CNN stores the tuple.
    auto modules =
            std::tie(convLayer0, biasLayer0, dropoutLayer0, baseLayer0, poolingLayer0,
                     convLayer1, biasLayer1, dropoutLayer1, baseLayer1, poolingLayer1,
                     linearLayer0, biasLayer2, softmaxLayer0);

    CNN<decltype(modules), decltype(outputLayer)>
            net(modules, outputLayer);

    // Same training setup as the vanilla network: at most 50 epochs,
    // batch size 1, tolerance 0.7, validating on the training data.
    Trainer<decltype(net)> trainer(net, 50, 1, 0.7);
    trainer.Train(input, Y, input, Y);

    predict_digits(net, input, 0, nPoints / 2, 5);
    predict_digits(net, input, nPoints / 2, nPoints, 8);

    BOOST_REQUIRE_LE(trainer.ValidationError(), 0.7);
}

/**
 * Train and evaluate the dropout network on the MNIST 4s-and-9s subset.
 */
BOOST_AUTO_TEST_CASE(VanillaNetworkDropoutTest)
{
    BuildVanillaDropoutNetwork<RectifierFunction>();
}

BOOST_AUTO_TEST_SUITE_END();