Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include "nn.h"
/*
 * Cheap sigmoid-shaped squashing function:
 *     f(a) = 0.5 * a / (1 + |a|) + 0.5
 * f(0) == 0.5 and f is monotonically increasing, approaching 0 as a -> -inf
 * and 1 as a -> +inf. The arithmetic is carried out in double (0.5 and
 * fabs() are double) and narrowed to float on return.
 */
static float sigmoid(float a)
{
    double scaled = 0.5 * a / (1.0 + fabs(a));
    return 0.5 + scaled;
}
- void nn_process(nn *net) {
- float hidden1[NN_NUMHIDDEN];
- float hidden2[NN_NUMHIDDEN];
- int w = 0;
- int i, h, o;
- /* hidden layer 1 */
- for (h = 0; h < NN_NUMHIDDEN; h++) {
- float activation = 0;
- for (i = 0; i < NN_NUMINPUTS; i++)
- activation += net->weight[w++] * net->input[i];
- activation += net->weight[w++];
- hidden1[h] = sigmoid(activation);
- }
- /* hidden layer 2 */
- for (h = 0; h < NN_NUMHIDDEN; h++) {
- float activation = 0;
- for (i = 0; i < NN_NUMHIDDEN; i++)
- activation += net->weight[w++] * hidden1[i];
- activation += net->weight[w++];
- hidden2[h] = sigmoid(activation);
- }
- /* output layer */
- for (o = 0; o < NN_NUMOUTPUTS; o++) {
- float activation = 0;
- for (h = 0; h < NN_NUMHIDDEN; h++)
- activation += net->weight[w++] * hidden2[h];
- activation += net->weight[w++];
- net->output[o] = sigmoid(activation);
- }
- }
- void nn_fillrandom(nn *net) {
- int w;
- for (w = 0; w < NN_NUMWEIGHTS; w++)
- net->weight[w] = RANDDBL(-8.0, 8.0);
- }
- void nn_mutate(nn *net) {
- if (DICE(10)) {
- /* high mutation */
- int w;
- for (w = 0; w < NN_NUMWEIGHTS; w++)
- if (DICE(200)) net->weight[w] += RANDDBL(-8, 8);
- }
- else {
- int w;
- for (w = 0; w < NN_NUMWEIGHTS; w++) {
- if (DICE(500)) net->weight[w] += RANDDBL(-1, 1);
- if (DICE(2000)) net->weight[w] += RANDDBL(-8, 8);
- }
- }
- /* clamp */
- int w;
- for (w = 0; w < NN_NUMWEIGHTS; w++)
- net->weight[w] = RETMAX(-10, RETMIN(10, net->weight[w]));
- }
- void nn_copy(nn *dest, nn *src) {
- int w;
- for (w = 0; w < NN_NUMWEIGHTS; w++)
- dest->weight[w] = src->weight[w];
- }
- void nn_print(nn *net) {
- int i, h, o;
- int w = 0;
- for (h = 0; h < NN_NUMHIDDEN; h++) {
- for (i = 0; i < NN_NUMINPUTS; i++)
- printf("%2.2f ", net->weight[w++]);
- printf("%2.2f\n", net->weight[w++]);
- }
- printf("\n\n");
- for (o = 0; o < NN_NUMOUTPUTS; o++) {
- for (h = 0; h < NN_NUMHIDDEN; h++)
- printf("%2.2f ", net->weight[w++]);
- printf("%2.2f\n", net->weight[w++]);
- }
- printf("\n\n");
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement