Guest User

Untitled

a guest
Jan 23rd, 2019
62
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 13.69 KB | None | 0 0
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
  7.  
  8. using namespace std;
  9.  
  10. ofstream logFile("last_log.txt");
  11.  
  12. double LEARNING_RATE = -1;
  13.  
  14. const double MOMENTUM_RATE = -1; // used to "nudge" weights away from local minima
  15.  
  16. int NUM_INPUT_NODES = -1;
  17. int NUM_HIDDEN_NODES = -1;
  18. int NUM_OUTPUT_NODES = -1; // must be one for now
  19.  
  20. int NUM_TRAINING_DATA_ROWS = -1;
  21. int NUM_EPOCHS = -1;
  22.  
  23. const int BIAS_NODE_COUNT = 1;
  24. const int BIAS_NODE_INDEX = 0;
  25.  
  26. const double BIAS_NODE_VALUE = -1.0f;
  27.  
  28. // these globals take bias nodes into account
  29. int REAL_NUM_INPUT_NODES = -1;
  30. int REAL_NUM_HIDDEN_NODES = -1;
  31. int REAL_NUM_OUTPUT_NODES = -1;
  32.  
  33. const double SIGMOID_CONSTANT = 1.0f;
  34.  
  35. vector< std::vector<double> > w0;
  36. vector< std::vector<double> > w1;
  37.  
  38. vector< std::vector<double> > w0_delta;
  39. vector< std::vector<double> > w1_delta;
  40.  
  41. // datasets
  42. vector< std::vector<double> > inputDataSet;
  43. vector< vector<double> > targetDataSet;
  44.  
  45. // layer outputs
  46. vector<double> input;
  47. vector<double> hidden;
  48. vector<double> output;
  49.  
  50. vector<double> target;
  51.  
  52. // layer errors
  53. vector<double> inputError;
  54. vector<double> hiddenError;
  55. vector<double> outputError;
  56.  
  57. // input min/max
  58. vector<double> inputMin;
  59. vector<double> inputMax;
  60.  
  61. // prototypes
  62. void backProp();
  63. void printState();
  64. double normInputEntry( int, double );
  65.  
  66.  
  67. double sigmoid(double x) {
  68. return 1.0/(1 + exp( -x/SIGMOID_CONSTANT));
  69. }
  70.  
  71. void initNetworkWeights() {
  72.  
  73. for ( int inputIndex = 0; inputIndex < REAL_NUM_INPUT_NODES; inputIndex++ ) {
  74.  
  75. for ( int hiddenIndex = 0; hiddenIndex < REAL_NUM_HIDDEN_NODES; hiddenIndex++ ) {
  76.  
  77. if(hiddenIndex == BIAS_NODE_INDEX) {
  78. w0[inputIndex][hiddenIndex] = 0.0;// make sure we preserve bias node of hidden layer
  79. }
  80. else {
  81. w0[inputIndex][hiddenIndex] = (rand()/(double)RAND_MAX * .6) - 0.3; // random (enough) number from -0.5 to 0.5
  82. }
  83. w0_delta[inputIndex][hiddenIndex] = 0.0;
  84. }
  85. }
  86.  
  87. for ( int hiddenIndex = 0; hiddenIndex < REAL_NUM_HIDDEN_NODES; hiddenIndex++ ) {
  88.  
  89. for ( int outputIndex = 0; outputIndex < REAL_NUM_OUTPUT_NODES; outputIndex++ ) {
  90. w1[hiddenIndex][outputIndex] = (rand()/(double)RAND_MAX * .6) - 0.3;//rand()/(double)RAND_MAX - 0.5f; // random (enough) number from -0.5 to 0.5
  91. w1_delta[hiddenIndex][outputIndex] = 0.0;
  92. }
  93. }
  94. }
  95.  
  96. void calcNetwork() {
  97.  
  98. for ( int hiddenNodeIndex = 1; hiddenNodeIndex < REAL_NUM_HIDDEN_NODES; hiddenNodeIndex++ ) {
  99. hidden[hiddenNodeIndex] = 0.0f;
  100.  
  101. for ( int inputNodeIndex = 0; inputNodeIndex < REAL_NUM_INPUT_NODES; inputNodeIndex++ ) {
  102. hidden[hiddenNodeIndex] += input[inputNodeIndex] * w0[inputNodeIndex][hiddenNodeIndex];
  103. }
  104. hidden[hiddenNodeIndex] = sigmoid(hidden[hiddenNodeIndex]);
  105. }
  106.  
  107. for ( int outputNodeIndex = 0; outputNodeIndex < REAL_NUM_OUTPUT_NODES; outputNodeIndex++ ) {
  108. output[outputNodeIndex] = 0;
  109.  
  110. for ( int hiddenNodeIndex = 0; hiddenNodeIndex < REAL_NUM_HIDDEN_NODES; hiddenNodeIndex++ ) {
  111. output[outputNodeIndex] += hidden[hiddenNodeIndex] * w1[hiddenNodeIndex][outputNodeIndex];
  112. }
  113. }
  114. }
  115.  
  116. void backProp() {
  117.  
  118. // calculate the output error
  119. for ( int outputNodeIndex = 0; outputNodeIndex < REAL_NUM_OUTPUT_NODES; outputNodeIndex++ ) {
  120. outputError[outputNodeIndex] = output[outputNodeIndex] *
  121. (1-output[outputNodeIndex]) *
  122. (target[outputNodeIndex]-output[outputNodeIndex]);
  123. }
  124.  
  125. // calculate the error for the hidden layer
  126. for ( int hiddenErrorIndex = 0; hiddenErrorIndex < REAL_NUM_HIDDEN_NODES; hiddenErrorIndex++ ) {
  127.  
  128. double errorSum = 0.0f;
  129.  
  130. for( int outputNodeIndex = 0; outputNodeIndex < REAL_NUM_OUTPUT_NODES; outputNodeIndex++ ) {
  131.  
  132. errorSum += w1[hiddenErrorIndex][outputNodeIndex] * outputError[outputNodeIndex];
  133. }
  134. hiddenError[hiddenErrorIndex] = hidden[hiddenErrorIndex] * (1.0f - hidden[hiddenErrorIndex]) * errorSum;
  135. }
  136.  
  137. // update weights going to the output layer
  138. for ( int hiddenNodeIndex = 0; hiddenNodeIndex < REAL_NUM_HIDDEN_NODES; hiddenNodeIndex++ ) {
  139. for ( int outputNodeIndex = 0; outputNodeIndex < REAL_NUM_OUTPUT_NODES; outputNodeIndex++ ) {
  140. w1[hiddenNodeIndex][outputNodeIndex] += LEARNING_RATE * outputError[outputNodeIndex] * hidden[hiddenNodeIndex];
  141. }
  142. }
  143.  
  144. // update weights going from the input to the output layer
  145. for ( int inputNodeIndex = 0; inputNodeIndex < REAL_NUM_INPUT_NODES; inputNodeIndex++ ) {
  146. for ( int hiddenNodeIndex = 1; hiddenNodeIndex < REAL_NUM_HIDDEN_NODES; hiddenNodeIndex++ ) {
  147.  
  148. w0[inputNodeIndex][hiddenNodeIndex] += LEARNING_RATE * hiddenError[hiddenNodeIndex] * input[inputNodeIndex];
  149. }
  150. }
  151. }
  152.  
  153. void printState() {
  154.  
  155. logFile << "#########################\n######## STATE: #########\n#########################\n\n";
  156.  
  157. logFile << "w0\n";
  158. for ( int x = 0; x <REAL_NUM_INPUT_NODES; x++ ) {
  159. for ( int y = 0; y <REAL_NUM_HIDDEN_NODES; y++ ) {
  160. logFile << w0[x][y] << " ";
  161. }
  162. logFile << endl;
  163. }
  164. logFile << endl;
  165.  
  166. logFile << "w1\n";
  167. for ( int x = 0; x <REAL_NUM_HIDDEN_NODES; x++ ) {
  168. for ( int y = 0; y <REAL_NUM_OUTPUT_NODES; y++ ) {
  169. logFile << w1[x][y] << " ";
  170. }
  171. logFile << endl;
  172. }
  173. logFile << endl;
  174.  
  175. logFile << "input\n";
  176. for ( int x = 0; x < REAL_NUM_INPUT_NODES; x++ ) {
  177.  
  178. if(x==0) logFile << "(b)";
  179. logFile << input[x] << " ";
  180. }
  181. logFile << endl;
  182. /*
  183. logFile << "inputDataSet\n";
  184. for ( int x = 0; x < NUM_TRAINING_DATA_ROWS; x++ ) {
  185. for ( int y = 0; y < NUM_INPUT_NODES; y++ ) {
  186. logFile << inputDataSet[x][y] << " ";
  187. }
  188. logFile << endl;
  189. }
  190. logFile << endl;
  191. */
  192. logFile << "inputMin\n";
  193. for ( int x = 0; x < NUM_INPUT_NODES; x++ ) {
  194.  
  195. if(x==0) logFile << "(b)";
  196. logFile << inputMin[x] << " ";
  197. }
  198. logFile << endl;
  199.  
  200. logFile << "inputMax\n";
  201. for ( int x = 0; x < NUM_INPUT_NODES; x++ ) {
  202.  
  203. if(x==0) logFile << "(b)";
  204. logFile << inputMax[x] << " ";
  205. }
  206. logFile << endl;
  207.  
  208. logFile << "hidden\n";
  209. for ( int x = 0; x < REAL_NUM_HIDDEN_NODES; x++ ) {
  210. logFile << hidden[x] << " ";
  211. }
  212. logFile << endl;
  213.  
  214. logFile << "output\n";
  215. for ( int x = 0; x < REAL_NUM_OUTPUT_NODES; x++ ) {
  216. logFile << output[x] << " ";
  217. }
  218. logFile << endl;
  219.  
  220. logFile << "target\n";
  221. for ( int x = 0; x < REAL_NUM_OUTPUT_NODES; x++ ) {
  222. logFile << target[x] << " ";
  223. }
  224. logFile << endl;
  225.  
  226.  
  227. logFile << "output error\n";
  228.  
  229. double errorSum = 0.0f;
  230. for ( int x = 0; x < REAL_NUM_OUTPUT_NODES; x++ ) {
  231. errorSum += target[x] - output[x];
  232. }
  233.  
  234. logFile << errorSum << "\n\n\n";
  235.  
  236. }
  237.  
  238. void prepareInput(std::string fileName) {
  239.  
  240. double max; // there will never be negative input (for now)
  241. double min;
  242.  
  243. ifstream dataFile(fileName.c_str());
  244.  
  245. if(!dataFile) {
  246. cerr << "File Open FAILED\n";
  247. exit(1);
  248. }
  249.  
  250. dataFile >> NUM_INPUT_NODES;
  251. dataFile >> NUM_HIDDEN_NODES;
  252. dataFile >> NUM_OUTPUT_NODES;
  253.  
  254. dataFile >> NUM_TRAINING_DATA_ROWS;
  255. dataFile >> NUM_EPOCHS;
  256.  
  257. dataFile >> LEARNING_RATE;
  258.  
  259. logFile << "\nNUM_INPUT_NODES: " << NUM_INPUT_NODES;
  260. logFile << "\nNUM_HIDDEN_NODES: " << NUM_HIDDEN_NODES;
  261. logFile << "\nNUM_OUTPUT_NODES: " << NUM_OUTPUT_NODES;
  262. logFile << "\nNUM_TRAINING_DATA_ROWS: " << NUM_TRAINING_DATA_ROWS;
  263. logFile << "\nNUM_EPOCHS: " << NUM_EPOCHS;
  264. logFile << "\nLEARNING_RATE: " << LEARNING_RATE;
  265. logFile << endl;
  266.  
  267. REAL_NUM_INPUT_NODES = NUM_INPUT_NODES + BIAS_NODE_COUNT;
  268. REAL_NUM_HIDDEN_NODES = NUM_HIDDEN_NODES + BIAS_NODE_COUNT;
  269. REAL_NUM_OUTPUT_NODES = NUM_OUTPUT_NODES;
  270.  
  271. // setup all of the memory we will need
  272. inputMin.resize(NUM_INPUT_NODES);
  273. inputMax.resize(NUM_INPUT_NODES);
  274.  
  275. input.resize(REAL_NUM_INPUT_NODES);
  276. hidden.resize(REAL_NUM_HIDDEN_NODES);
  277. output.resize(REAL_NUM_OUTPUT_NODES);
  278.  
  279. target.resize(REAL_NUM_OUTPUT_NODES);
  280.  
  281. inputError.resize(REAL_NUM_INPUT_NODES);
  282. hiddenError.resize(REAL_NUM_HIDDEN_NODES);
  283. outputError.resize(REAL_NUM_OUTPUT_NODES);
  284.  
  285. inputDataSet.resize(NUM_TRAINING_DATA_ROWS, vector<double>(NUM_INPUT_NODES));
  286.  
  287. w0.resize(REAL_NUM_INPUT_NODES, vector<double>(REAL_NUM_HIDDEN_NODES));
  288. w1.resize(REAL_NUM_HIDDEN_NODES, vector<double>(REAL_NUM_OUTPUT_NODES));
  289.  
  290. w0_delta.resize(REAL_NUM_INPUT_NODES, vector<double>(REAL_NUM_HIDDEN_NODES));
  291. w1_delta.resize(REAL_NUM_HIDDEN_NODES, vector<double>(REAL_NUM_OUTPUT_NODES));
  292.  
  293. targetDataSet.resize(NUM_TRAINING_DATA_ROWS, vector<double>(REAL_NUM_OUTPUT_NODES));
  294.  
  295. cout << "inputDataSet.size(): " << inputDataSet.size() << endl;
  296. for( int inputMinIndex = 0; inputMinIndex < NUM_INPUT_NODES; inputMinIndex++ ) {
  297. dataFile >> inputMin[inputMinIndex];
  298. }
  299. for( int inputMaxIndex = 0; inputMaxIndex < NUM_INPUT_NODES; inputMaxIndex++ ) {
  300. dataFile >> inputMax[inputMaxIndex];
  301. }
  302.  
  303. // fill in the input and target data
  304. for( int dataRowIndex = 0; dataRowIndex < NUM_TRAINING_DATA_ROWS; dataRowIndex++ ) {
  305. cout << "inputDataSet[" << dataRowIndex <<"].size(): " << inputDataSet[0].size() << endl;
  306. for( int inputNodeIndex = 0; inputNodeIndex < NUM_INPUT_NODES; inputNodeIndex++ ) {
  307. dataFile >> inputDataSet[dataRowIndex][inputNodeIndex];
  308. logFile << "i[" << dataRowIndex << "]["
  309. << inputNodeIndex << "]: " << inputDataSet[dataRowIndex][inputNodeIndex] << " ";
  310. // normalize data
  311. inputDataSet[dataRowIndex][inputNodeIndex] = normInputEntry(inputNodeIndex, inputDataSet[dataRowIndex][inputNodeIndex]);
  312. logFile << "Before assert: inputDataSet[" << dataRowIndex << "]["
  313. << inputNodeIndex << "] = " << inputDataSet[dataRowIndex][inputNodeIndex] << endl;
  314. assert(inputDataSet[dataRowIndex][inputNodeIndex] >= 0.0);
  315.  
  316. }
  317. for( int targetNodeIndex = 0; targetNodeIndex < REAL_NUM_OUTPUT_NODES; targetNodeIndex++ ) {
  318.  
  319. dataFile >> targetDataSet[dataRowIndex][targetNodeIndex];
  320. logFile << "t[" << dataRowIndex << "]["
  321. << targetNodeIndex << "]: " << targetDataSet[dataRowIndex][targetNodeIndex] << " ";
  322. }
  323. logFile << endl;
  324. }
  325. printState();
  326. logFile << "\nAfter File Read: \n";
  327. printState();
  328. dataFile.close();
  329. }
  330.  
  331. double normInputEntry(int inputIndex, double inputVal) {
  332.  
  333. double result = (inputVal - inputMin[inputIndex])/(inputMax[inputIndex] - inputMin[inputIndex]);
  334.  
  335. logFile << "Converted input " << inputVal
  336. << " into " << result << " (min: " << inputMin[inputIndex]
  337. << ", max: " << inputMax[inputIndex] << ", inputIndex: " << inputIndex << ")\n";
  338.  
  339. return result;
  340. }
  341.  
  342. int main() {
  343.  
  344. srand((unsigned)time(0));
  345.  
  346. prepareInput("iris_data.txt");
  347. initNetworkWeights();
  348.  
  349. for( int epochIndex = 0; epochIndex < NUM_EPOCHS; epochIndex++ ) {
  350.  
  351. for( int dataRowIndex = 0; dataRowIndex < NUM_TRAINING_DATA_ROWS; dataRowIndex++ ) {
  352. logFile << "Epoch " << epochIndex << " row "
  353. << dataRowIndex << " ("
  354. << (epochIndex*NUM_TRAINING_DATA_ROWS + dataRowIndex)/(NUM_EPOCHS*NUM_TRAINING_DATA_ROWS)
  355. << "%)\n";
  356. cout << "Epoch " << epochIndex << " row "
  357. << dataRowIndex << " ("
  358. << 100*(epochIndex*NUM_TRAINING_DATA_ROWS + dataRowIndex)/(NUM_EPOCHS*NUM_TRAINING_DATA_ROWS)
  359. << "%)\n";
  360.  
  361. // present the input
  362. input[0] = BIAS_NODE_VALUE;
  363. for( int inputNodeIndex = 1; inputNodeIndex < REAL_NUM_INPUT_NODES; inputNodeIndex++ ) {
  364. input[inputNodeIndex] = inputDataSet[dataRowIndex][inputNodeIndex-1];
  365. }
  366. // present the targets
  367. for( int targetNodeIndex = 0; targetNodeIndex < REAL_NUM_OUTPUT_NODES; targetNodeIndex++ ) {
  368. target[targetNodeIndex] = targetDataSet[dataRowIndex][targetNodeIndex];
  369. }
  370.  
  371. calcNetwork();
  372. backProp();
  373. printState();
  374.  
  375. if(outputError[0] != outputError[0]) {
  376. initNetworkWeights();
  377. epochIndex = 0;
  378. dataRowIndex = 0;
  379. cout << "\n\nQNAN FOUND during data row " << dataRowIndex << ". Resetting weights.\n";
  380. }
  381. }
  382. }
  383.  
  384. printState();
  385.  
  386. double num1 = -1;
  387.  
  388. do {
  389. input[0] = BIAS_NODE_VALUE;
  390. for( int inputNodeIndex = 1; inputNodeIndex < REAL_NUM_INPUT_NODES; inputNodeIndex++ ) {
  391. cout << "enter num ( < 0 to quit)" << inputNodeIndex << ":";
  392. cin >> input[inputNodeIndex];
  393.  
  394. if(input[inputNodeIndex] < 0)
  395. return 1;
  396. input[inputNodeIndex] = normInputEntry( inputNodeIndex-1, input[inputNodeIndex] );// MAX_INPUT_NUMBER;
  397. }
  398.  
  399. calcNetwork();
  400. printState();
  401. cout << "network response: ";
  402. for( int outputNodeIndex = 0; outputNodeIndex < REAL_NUM_OUTPUT_NODES; outputNodeIndex++ ) {
  403. cout << output[outputNodeIndex] << " ";
  404. }
  405. cout << endl;
  406.  
  407. } while ( 1 );
  408.  
  409. logFile.close();
  410. }
Add Comment
Please, Sign In to add comment