Pella86

Node ifstream

Apr 26th, 2020
112
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.38 KB | None | 0 0
  1. //Node.h
  2.  
  3. #ifndef NODE_H
  4. #define NODE_H
  5.  
  6. #include <vector>
  7. #include <fstream>
  8. #include <iostream>
  9.  
  10. class Node
  11. {
  12.     public:
  13.         Node();
  14.         Node(std::vector<double> weights, double bias);
  15.         virtual ~Node();
  16.  
  17.         double output(std::vector<double> input);
  18.  
  19.         friend std::ostream& operator << (std::ostream& out, const Node& node);
  20.         friend std::ofstream& operator << (std::ofstream& out, const Node& node);
  21.         friend std::ifstream& operator >> (std::ifstream& in, Node& node);
  22.  
  23.     protected:
  24.  
  25.     private:
  26.  
  27.         std::vector<double> weights;
  28.         double bias;
  29.  
  30.         double z(std::vector<double> input);
  31.  
  32. };
  33.  
  34. void test_node();
  35.  
  36. #endif // NODE_H
  37.  
  38.  
  39. // ------------ Node.cpp -------------------------
#include "Node.h"

#include <cmath>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
  45.  
  46. double dot(std::vector<double> v1, std::vector<double> v2){
  47.     if(v1.size() != v2.size())
  48.     {
  49.             throw std::length_error(std::string("dot product: invalid vector lengths"));
  50.     }
  51.  
  52.     double dot_product;
  53.  
  54.     for(size_t i = 0; i < v1.size(); ++i){
  55.         dot_product += v1[i] * v2[i];
  56.     }
  57.  
  58.     return dot_product;
  59.  
  60. }
  61.  
  62. double sigmoid(double z){
  63.     if(z < 0){
  64.         double s = 1 - 1 / ( 1 + exp(z));
  65.         return s;
  66.     }
  67.     else{
  68.         double s = 1 / ( 1 + exp(-z));
  69.         return s;
  70.     }
  71. }
  72.  
  73.  
  74. Node::Node(){
  75. }
  76.  
  77. Node::Node(std::vector<double> iweights, double ibias) : weights(iweights), bias(ibias)
  78. {
  79. }
  80.  
  81. Node::~Node()
  82. {
  83.     //dtor
  84. }
  85.  
  86. double Node::z(std::vector<double> input)
  87. {
  88.     return dot(input, weights) + bias;
  89. }
  90.  
  91. double Node::output(std::vector<double> input)
  92. {
  93.     return sigmoid(z(input));
  94. }
  95.  
  96.  
  97. std::ostream& operator << (std::ostream& out, const Node& node)
  98. {
  99.     out << "Node [weights(" << node.weights.size() << ")";
  100.  
  101.     if(node.weights.size() > 0)
  102.     {
  103.         out << ": {";
  104.         for(auto w = node.weights.begin(); w != (node.weights.end() - 1); ++w)
  105.         {
  106.             out << *w << ", ";
  107.         }
  108.         out << * (node.weights.end() - 1) << "}";
  109.     }
  110.     out << " bias: " << node.bias << "]";
  111.  
  112.     return out;
  113. }
  114.  
  115. std::ofstream& operator << (std::ofstream& out, const Node& node){
  116.     out << node.bias;
  117.     out << (size_t) node.weights.size();
  118.     for(auto w : node.weights)
  119.     {
  120.         out << w;
  121.     }
  122.  
  123.     return out;
  124. }
  125.  
  126. std::ifstream& operator >> (std::ifstream& in, Node& node)
  127. {
  128.  
  129.     in >> node.bias;
  130.     std::cout << node.bias << std::endl;
  131.  
  132.     size_t sz;
  133.     in >> sz;
  134.  
  135.     std::vector<double> weights;
  136.     for(size_t i = 0; i < sz; ++i){
  137.         double w;
  138.         in >> w;
  139.         weights.push_back(w);
  140.     }
  141.     node.weights = weights;
  142.  
  143.     return in;
  144. }
  145.  
  146.  
  147. void test_node(){
  148.  
  149.     std::vector<double> v{1, 2, 3, 4};
  150.     std::cout << v.size() << std::endl;
  151.  
  152.     Node n(v, 10);
  153.  
  154.     std::cout << n << std::endl;
  155.     std::cout << "Writing to file..." << std::endl;
  156.  
  157.     std::ofstream node_file_out("./test/node_file.ndf", std::ios::out|std::ios::binary);
  158.  
  159.     if(node_file_out.is_open())
  160.     {
  161.         std::cout << "hello" << std::endl;
  162.         node_file_out << n;
  163.         node_file_out.flush();
  164.         node_file_out.close();
  165.     }
  166.  
  167.     Node n2;
  168.  
  169.     std::ifstream node_file_in("./test/node_file.ndf", std::ios::in|std::ios::binary);
  170.     node_file_in >> n2;
  171.     node_file_in.close();
  172.  
  173.     std::cout << n2 << std::endl;
  174.  
  175. }
Add Comment
Please, Sign In to add comment