Advertisement
Guest User

Untitled

a guest
Sep 25th, 2021
41
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
HTML 19.52 KB | None | 0 0
  1. <html>
  2. <!DOCTYPE html>
  3.  
  4. <head>
  5.     <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
  6.  
  7.     <title>NN Experiment</title>
  8.  
  9.     <script src="https://d3js.org/d3.v6.min.js"></script>
  10.     <script src="math_lib.js"></script>
  11.     <script src="nn.js"></script>
  12.  
  13.     <style>
  14.         html,
  15.         body {
  16.             overflow: hidden;
  17.             width: 100%;
  18.             height: 100%;
  19.             margin: 0;
  20.             padding: 0;
  21.         }
  22.     </style>
  23. </head>
  24.  
  25. <body>
  26.     <script>let inputData = 0;</script>
  27.     <span>Layer: </span>
  28.     <button type="button" onclick="decrease()">&lt;</button>    
  29.     <span id="layerNumberSpan">0</span>
  30.     <button type="button" onclick="increase()">&gt;</button>    
  31.     <span id="layerNumberSpan">&nbsp;</span>
  32.     <input type="number" style="width: 80px;" id="countTrains" value="1" step="1000"></input>    
  33.     <button type="button" onclick="train()">Train</button>    
  34.     <span>Learning Speed: </span>
  35.     <input type="number" style="width: 80px;" id="learningSpeed" value="0.1" step="0.1" onchange="set_learning_speed()"></input>    
  36.     <div id="area" style='border: solid 1px red;'></div>
  37.     <script>
  38.         function increase()
  39.         {
  40.             inputData++;        
  41.             if (inputData >= nn.X.length)
  42.             {
  43.                 inputData = nn.X.length - 1;
  44.             }
  45.  
  46.             repopulate();
  47.         }
  48.  
  49.         function decrease()
  50.         {
  51.             inputData--;
  52.             if (inputData < 0)
  53.            {
  54.                inputData = 0;
  55.            }
  56.  
  57.            repopulate();
  58.        }
  59.  
  60.        function repopulate()
  61.        {
  62.            document.getElementById("layerNumberSpan").innerText = inputData;
  63.            runner.show();
  64.        }
  65.  
  66.        function train()
  67.        {
  68.            nn.train(Number(document.getElementById("countTrains").value));
  69.            runner.bind();            
  70.            repopulate();
  71.        }
  72.  
  73.        function set_learning_speed()
  74.        {
  75.            nn.alpha = Number(document.getElementById("learningSpeed").value);
  76.        }
  77.  
  78.        const svgWidth = 1600;
  79.        const svgHeight = 1024;
  80.  
  81.        const radius = 40;
  82.        const radiusWeight = 20;
  83.  
  84.        const LayerTypes = {
  85.            INPUT: 'input',
  86.            HIDDEN: 'hidden',
  87.            OUTPUT: 'output'
  88.        }
  89.  
  90.        class Layer {
  91.            
  92.            constructor(count) {
  93.                this.nodeCounts = count;
  94.            }
  95.  
  96.            nodeCoords(layerIndex, index) {
  97.                return [layerIndex * 25 + 5, index * 30 + 20];
  98.            }
  99.  
  100.            createGraph(runner) {
  101.                const svg = runner.svg;
  102.                const xscale = runner.x;
  103.                const yscale = runner.y;
  104.  
  105.                const weights = this.weights;
  106.                const deltaWeights = this.deltaWeights;
  107.                const input = this.input[inputData];
  108.                const inputAfterActivation = this.inputAfterActivation[inputData];
  109.                const output = this.type != LayerTypes.INPUT ? this.output[inputData] : undefined;
  110.                const expectingVector = this.type == LayerTypes.OUTPUT ? this.expecting[inputData] : undefined;
  111.                const error = this.type == LayerTypes.OUTPUT ? this.error[inputData] : undefined;
  112.  
  113.                range(this.nodeCounts).forEach((_, i) => {
  114.  
  115.                     const [xn, yn] = this.nodeCoords(this.layerIndex, i);
  116.                     const cx = xscale(xn);
  117.                     const cy = yscale(yn);
  118.  
  119.                     const cx2 = cx + xscale(2);
  120.                     const cy2 = cy + yscale(-4);
  121.  
  122.                     // total weight
  123.                     svg.append("circle")
  124.                         .attr("cx", cx)
  125.                         .attr("cy", cy)
  126.                         .attr("r", radius)
  127.                         .attr("stroke", "black")
  128.                         .attr("stroke-width", "3")
  129.                         .attr("fill", "green");
  130.  
  131.                     svg.append("text")
  132.                         .attr("x", cx)
  133.                         .attr("y", cy)
  134.                         .attr("fill", "white")
  135.                         .attr("text-anchor", "middle")
  136.                         .attr("dominant-baseline", "middle")
  137.                         .attr("font-size", "2em")
  138.                         .text(input[i].toFixed(2));
  139.  
  140.                     // output
  141.                     svg.append("circle")
  142.                         .attr("cx", cx2)
  143.                         .attr("cy", cy2)
  144.                         .attr("r", radius - 20)
  145.                         .attr("stroke", "black")
  146.                         .attr("stroke-width", "3")
  147.                         .attr("fill", "white");
  148.  
  149.                     svg.append("text")
  150.                         .attr("x", cx2)
  151.                         .attr("y", cy2)
  152.                         .attr("fill", "black")
  153.                         .attr("text-anchor", "middle")
  154.                         .attr("dominant-baseline", "middle")
  155.                         .attr("font-size", "1em")
  156.                         .text(inputAfterActivation[i].toFixed(2));
  157.  
  158.                     // draw lines
  159.                     if (this.layerIndex > 0) {
  160.                         const inputLayerIndex = this.layerIndex - 1;
  161.                         const inputCount = weights.length;
  162.  
  163.                         range(this.nodeCounts)
  164.                             .forEach((_, nodeIndex) => range(inputCount)
  165.                                 .forEach((_, inputNodeIndex) => {
  166.  
  167.                                     const [x1, y1] = this.nodeCoords(this.layerIndex, nodeIndex);
  168.                                     const [x2, y2] = this.nodeCoords(inputLayerIndex, inputNodeIndex);
  169.  
  170.                                     const x1abs = xscale(x1) - radius;
  171.                                     const y1abs = yscale(y1);
  172.                                     const x2abs = xscale(x2) + radius;
  173.                                     const y2abs = yscale(y2);
  174.  
  175.                                     //const xstep = (x2abs - x1abs) / (this.nodeCounts + 1);
  176.                                     //const ystep = (y2abs - y1abs) / (this.nodeCounts + 1);
  177.                                     const xstep = (x2abs - x1abs) / 4;
  178.                                     const ystep = (y2abs - y1abs) / 4;
  179.  
  180.                                     //const shift = (nodeIndex - (this.nodeCounts - 1) / 2);
  181.                                     const shift = 1;
  182.                                     let midxabs = ((x2abs + x1abs) / 2) - xstep * shift;
  183.                                     let midyabs = ((y2abs + y1abs) / 2) - ystep * shift;
  184.  
  185.                                     const weight = weights[inputNodeIndex][nodeIndex];
  186.  
  187.                                     const lineColor = weight >= 0 ? "blue" : "red";
  188.  
  189.                                     svg.append("line")
  190.                                         .attr("x1", x1abs)
  191.                                         .attr("y1", y1abs)
  192.                                         .attr("x2", x2abs)
  193.                                         .attr("y2", y2abs)
  194.                                         .attr("stroke", lineColor);
  195.  
  196.                                     // weights
  197.                                     svg.append("circle")
  198.                                         .attr("cx", midxabs)
  199.                                         .attr("cy", midyabs)
  200.                                         .attr("r", radiusWeight)
  201.                                         .attr("stroke", "green")
  202.                                         .attr("stroke-width", "1")
  203.                                         .attr("fill", "white");
  204.  
  205.                                     svg.append("text")
  206.                                         .attr("x", midxabs)
  207.                                         .attr("y", midyabs)
  208.                                         .attr("text-anchor", "middle")
  209.                                         .attr("dominant-baseline", "middle")
  210.                                         .attr("font-size", "1.3em")
  211.                                         .attr("stroke", "blue")
  212.                                         .attr("stroke-width", "1")
  213.                                         .attr("fill", "green")
  214.                                         .text(weight.toFixed(2));
  215.  
  216.                                     // circle around multi weight
  217.                                     svg.append("circle")
  218.                                         .attr("cx", midxabs + xscale(2))
  219.                                         .attr("cy", midyabs - yscale(1.5))
  220.                                         .attr("r", radiusWeight - 5)
  221.                                         .attr("stroke", "blue")
  222.                                         .attr("stroke-width", "1")
  223.                                         .attr("fill", "white");
  224.                                                                            
  225.                                     const weightMulInput = weights[inputNodeIndex][nodeIndex] * output[inputNodeIndex];
  226.                                     svg.append("text")
  227.                                         .attr("x", midxabs + xscale(2))
  228.                                         .attr("y", midyabs - yscale(1.5))
  229.                                         .attr("fill", "white")
  230.                                         .attr("text-anchor", "middle")
  231.                                         .attr("dominant-baseline", "middle")
  232.                                         .attr("font-size", "1em")
  233.                                         .attr("stroke", "green")
  234.                                         .attr("stroke-width", "1")
  235.                                         .attr("fill", "green")
  236.                                         .text(weightMulInput.toFixed(2));
  237.                                        
  238.                                     if (deltaWeights)
  239.                                     {
  240.                                         // circle around delta weight
  241.                                         svg.append("circle")
  242.                                             .attr("cx", midxabs)
  243.                                             .attr("cy", midyabs + yscale(3.3))
  244.                                             .attr("r", radiusWeight - 5)
  245.                                             .attr("stroke", "red")
  246.                                             .attr("stroke-width", "1")
  247.                                             .attr("fill", "white");
  248.  
  249.                                         const deltaWeight = deltaWeights[inputNodeIndex][nodeIndex];
  250.                                         svg.append("text")
  251.                                             .attr("x", midxabs)
  252.                                             .attr("y", midyabs + yscale(3.3))
  253.                                             .attr("fill", "white")
  254.                                             .attr("text-anchor", "middle")
  255.                                             .attr("dominant-baseline", "middle")
  256.                                             .attr("font-size", "1em")
  257.                                             .attr("stroke", "red")
  258.                                             .attr("stroke-width", "1")
  259.                                             .attr("fill", "red")
  260.                                             .text(deltaWeight.toFixed(2));
  261.  
  262.                                         // arrow
  263.                                         let rotate = "";
  264.                                         let color = "ff0000";
  265.                                         if (deltaWeight > 0)
  266.                                         {
  267.                                             color = "00ff00";
  268.                                             rotate = " rotate(180)";
  269.                                             midxabs += xscale(1.0);
  270.                                             midyabs += yscale(1.9);
  271.                                         }
  272.  
  273.                                         svg.append("g")
  274.                                             .attr("transform", "translate(" + (midxabs + xscale(1.0)) + ", " + (midyabs + yscale(2.2)) + ") scale(0.025, 0.025)" + rotate)
  275.                                             .append("path")
  276.                                             .attr("d", "M 680.00001,-27.858649 C 680.00001,-26.000219 330.00773,672.12987 329.42789,671.42805 C 327.4824,669.07328 -20.639627,-28.032809 -19.999107,-28.291219 C -19.580737,-28.459999 59.298143,5.3997002 155.28729,46.952552 L 329.81303,122.50321 L 504.05513,46.965887 C 599.8883,5.4203742 678.68037,-28.571419 679.14863,-28.571419 C 679.61689,-28.571419 680.00001,-28.250669 680.00001,-27.858649 z")
  277.                                             .attr("style", "fill:#" +color + ";fill-opacity:1;stroke:none")
  278.                                     }
  279.                                 }));
  280.                     }
  281.  
  282.                     // expecting result
  283.                     if (expectingVector) {
  284.                         const cx3 = xscale(xn + 10);
  285.                         const cy3 = yscale(yn);
  286.  
  287.                         svg.append("circle")
  288.                             .attr("cx", cx3)
  289.                             .attr("cy", cy3)
  290.                             .attr("r", radius)
  291.                             .attr("stroke", "black")
  292.                             .attr("stroke-width", "3")
  293.                             .attr("fill", "white");
  294.  
  295.                         svg.append("text")
  296.                             .attr("x", cx3)
  297.                             .attr("y", cy3)
  298.                             .attr("fill", "black")
  299.                             .attr("text-anchor", "middle")
  300.                             .attr("dominant-baseline", "middle")
  301.                             .attr("font-size", "3em")
  302.                             .text(expectingVector[i]);
  303.  
  304.                         svg.append("line")
  305.                             .attr("x1", cx + radius)
  306.                             .attr("y1", cy)
  307.                             .attr("x2", cx3 - radius)
  308.                             .attr("y2", cy3)
  309.                             .attr("stroke", "blue");
  310.                     }
  311.                 });
  312.  
  313.                 if (this.type == LayerTypes.OUTPUT) {
  314.                     const [xn, yn] = this.nodeCoords(this.layerIndex, this.nodeCounts + 1);
  315.                     const ex = xscale(xn);
  316.                     const ey = yscale(yn);
  317.  
  318.                     svg.append("text")
  319.                         .attr("x", ex + 20)
  320.                         .attr("y", ey)
  321.                         .attr("fill", "black")
  322.                         .attr("text-anchor", "middle")
  323.                         .attr("dominant-baseline", "middle")
  324.                         .attr("font-size", "2em")
  325.                         .text("E = " + error[0].toFixed(4));
  326.                 }
  327.             }
  328.         }
  329.  
  330.         class Runner {
  331.  
  332.             // [input, hidden, ..., hidden, output]
  333.             constructor(layerNumNodes) {
  334.                 this.layers = layerNumNodes.map((n, i) => {
  335.                     const l = new Layer(n);
  336.                     l.layerIndex = i;
  337.  
  338.                     if (i == 0)
  339.                     {
  340.                         l.type = LayerTypes.INPUT;
  341.                     }
  342.                     else if (i == (layerNumNodes.length - 1))
  343.                     {
  344.                         l.type = LayerTypes.OUTPUT;
  345.                     }
  346.                     else
  347.                     {
  348.                         l.type = LayerTypes.HIDDEN;
  349.                     }
  350.  
  351.                     return l;
  352.                 });
  353.             }
  354.  
  355.             createPresentation() {
  356.                 const margin = { top: 10, right: 10, bottom: 10, left: 10 };
  357.                 const width = svgWidth - margin.left - margin.right;
  358.                 const height = svgHeight - margin.top - margin.bottom;
  359.                 const svgwidth = svgWidth + margin.left + margin.right;
  360.                 const svgheight = svgHeight + margin.top + margin.bottom;
  361.  
  362.                 if (this.svg)
  363.                 {
  364.                     d3.select("#area").select("svg").remove();
  365.                 }
  366.  
  367.                 // svg
  368.                 this.svg =
  369.                     d3.select("#area")
  370.                         .append("svg")
  371.                         .attr("width", svgwidth)
  372.                         .attr("height", svgheight)
  373.                         .attr("style", 'border: solid 1px green;')
  374.                         .append("g")
  375.                         .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
  376.  
  377.                 // axises
  378.                 this.x = d3.scaleLinear().domain([0, 100]).range([0, width]);
  379.                 this.y = d3.scaleLinear().domain([0, 100]).range([0, height]);
  380.  
  381.                 this.svg
  382.                     .append('g')
  383.                     .attr("transform", "translate(0," + height + ")")
  384.                     .call(d3.axisBottom(this.x));
  385.  
  386.                 this.svg
  387.                     .append('g')
  388.                     .call(d3.axisLeft(this.y));
  389.             }
  390.  
  391.             setInputVector(inputVector) {
  392.                 return this;
  393.             }
  394.  
  395.             setExpectingOutput(expectingVector) {
  396.                 return this;
  397.             }
  398.  
  399.             show() {
  400.                 this.createPresentation();
  401.                 this.layers.forEach(l => l.createGraph(this));
  402.             }
  403.         }
  404.  
  405.         const nn = new NNSimulation(2, 3)
  406.             .setInput([
  407.                 [0, 0, 1],
  408.                 [0, 1, 1],
  409.                 [1, 0, 0],
  410.                 [1, 1, 0],
  411.                 [1, 0, 1],
  412.                 [1, 1, 1],])
  413.             .setOutput(
  414.                 [0, 1, 0, 1, 1, 0])
  415.             .init()
  416.             .forward()
  417.             //.train(10000)
  418.             ;
  419.  
  420.         const runner = new Runner([3, 3, 3, 1]);
  421.        
  422.         runner.bind = () => {
  423.             // input
  424.             runner.layers[0].input = nn.X/*[inputData]*/;
  425.             runner.layers[0].inputAfterActivation = nn.X/*[inputData]*/;
  426.  
  427.             // hidden layers
  428.             for (let i = 1; i <= nn.num_hidden_layers; i++)
  429.            {
  430.                const layerIndex = i - 1;
  431.                runner.layers[i].output = runner.layers[i - 1].inputAfterActivation;
  432.                runner.layers[i].weights = nn.hidden_weights[layerIndex];
  433.                if (nn.hidden_weights_change)
  434.                {
  435.                    runner.layers[i].deltaWeights = nn.hidden_weights_change[layerIndex];
  436.                }
  437.  
  438.                runner.layers[i].input = nn.hidden_layer_outputs_before_activation[layerIndex]/*[inputData]*/;
  439.                runner.layers[i].inputAfterActivation = nn.hidden_layer_outputs[layerIndex]/*[inputData]*/;
  440.            }
  441.  
  442.            const outputLayer = nn.num_hidden_layers + 1;
  443.  
  444.            // output
  445.            runner.layers[outputLayer].output = runner.layers[nn.num_hidden_layers].inputAfterActivation;
  446.            runner.layers[outputLayer].weights = nn.output_weights;
  447.            if (nn.output_weights_change)
  448.            {
  449.                runner.layers[outputLayer].deltaWeights = nn.output_weights_change;
  450.            }
  451.  
  452.            runner.layers[outputLayer].input = nn.output_layer_outputs_before_activation/*[inputData]*/;
  453.            runner.layers[outputLayer].inputAfterActivation = nn.output_layer_outputs/*[inputData]*/;
  454.            runner.layers[outputLayer].expecting = nn.Y/*[inputData]*/;
  455.            if (nn.output_error)
  456.            {
  457.                runner.layers[outputLayer].error = nn.output_error/*[inputData]*/;
  458.            }
  459.        }
  460.  
  461.        runner.bind();
  462.        runner.show();
  463.    </script>
  464. </body>
  465.  
  466. </html>
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement