Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- <!DOCTYPE html>
- <!-- The doctype has to be the first thing in the document; placing it after <html> puts browsers into quirks mode. -->
- <html>
- <head>
-     <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
-     <title>NN Experiment</title>
-     <!-- d3 draws the SVG diagram. -->
-     <script src="https://d3js.org/d3.v6.min.js"></script>
-     <!-- Local helpers; presumably math_lib.js supplies range() and nn.js supplies NNSimulation (verify against those files). -->
-     <script src="math_lib.js"></script>
-     <script src="nn.js"></script>
-     <style>
-         /* Full-window page without scrollbars. */
-         html,
-         body {
-             overflow: hidden;
-             width: 100%;
-             height: 100%;
-             margin: 0;
-             padding: 0;
-         }
-     </style>
- </head>
- <body>
- <!-- Index into nn.X (and into the per-sample arrays bound to each layer) of the sample currently drawn. -->
- <script>let inputData = 0;</script>
- <span>Layer: </span>
- <!-- The arrow glyphs are written as character references so they cannot be parsed as markup. -->
- <button type="button" onclick="decrease()">&lt;</button>
- <span id="layerNumberSpan">0</span>
- <button type="button" onclick="increase()">&gt;</button>
- <!-- Spacer only; it must not reuse the id of the counter span above. -->
- <span> </span>
- <!-- Number of training iterations passed to nn.train() by the Train button. -->
- <input type="number" style="width: 80px;" id="countTrains" value="1" step="1000" />
- <button type="button" onclick="train()">Train</button>
- <span>Learning Speed: </span>
- <input type="number" style="width: 80px;" id="learningSpeed" value="0.1" step="0.1" onchange="set_learning_speed()" />
- <!-- The SVG diagram is appended into this container. -->
- <div id="area" style='border: solid 1px red;'></div>
- <script>
- function increase()
- {
- inputData++;
- if (inputData >= nn.X.length)
- {
- inputData = nn.X.length - 1;
- }
- repopulate();
- }
- function decrease()
- {
- inputData--;
- if (inputData < 0)
- {
- inputData = 0;
- }
- repopulate();
- }
- function repopulate()
- {
- document.getElementById("layerNumberSpan").innerText = inputData;
- runner.show();
- }
- function train()
- {
- nn.train(Number(document.getElementById("countTrains").value));
- runner.bind();
- repopulate();
- }
- function set_learning_speed()
- {
- nn.alpha = Number(document.getElementById("learningSpeed").value);
- }
- const svgWidth = 1600;
- const svgHeight = 1024;
- const radius = 40;
- const radiusWeight = 20;
- const LayerTypes = {
- INPUT: 'input',
- HIDDEN: 'hidden',
- OUTPUT: 'output'
- }
- class Layer {
- constructor(count) {
- this.nodeCounts = count;
- }
- nodeCoords(layerIndex, index) {
- return [layerIndex * 25 + 5, index * 30 + 20];
- }
- createGraph(runner) {
- const svg = runner.svg;
- const xscale = runner.x;
- const yscale = runner.y;
- const weights = this.weights;
- const deltaWeights = this.deltaWeights;
- const input = this.input[inputData];
- const inputAfterActivation = this.inputAfterActivation[inputData];
- const output = this.type != LayerTypes.INPUT ? this.output[inputData] : undefined;
- const expectingVector = this.type == LayerTypes.OUTPUT ? this.expecting[inputData] : undefined;
- const error = this.type == LayerTypes.OUTPUT ? this.error[inputData] : undefined;
- range(this.nodeCounts).forEach((_, i) => {
- const [xn, yn] = this.nodeCoords(this.layerIndex, i);
- const cx = xscale(xn);
- const cy = yscale(yn);
- const cx2 = cx + xscale(2);
- const cy2 = cy + yscale(-4);
- // total weight
- svg.append("circle")
- .attr("cx", cx)
- .attr("cy", cy)
- .attr("r", radius)
- .attr("stroke", "black")
- .attr("stroke-width", "3")
- .attr("fill", "green");
- svg.append("text")
- .attr("x", cx)
- .attr("y", cy)
- .attr("fill", "white")
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "2em")
- .text(input[i].toFixed(2));
- // output
- svg.append("circle")
- .attr("cx", cx2)
- .attr("cy", cy2)
- .attr("r", radius - 20)
- .attr("stroke", "black")
- .attr("stroke-width", "3")
- .attr("fill", "white");
- svg.append("text")
- .attr("x", cx2)
- .attr("y", cy2)
- .attr("fill", "black")
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "1em")
- .text(inputAfterActivation[i].toFixed(2));
- // draw lines
- if (this.layerIndex > 0) {
- const inputLayerIndex = this.layerIndex - 1;
- const inputCount = weights.length;
- range(this.nodeCounts)
- .forEach((_, nodeIndex) => range(inputCount)
- .forEach((_, inputNodeIndex) => {
- const [x1, y1] = this.nodeCoords(this.layerIndex, nodeIndex);
- const [x2, y2] = this.nodeCoords(inputLayerIndex, inputNodeIndex);
- const x1abs = xscale(x1) - radius;
- const y1abs = yscale(y1);
- const x2abs = xscale(x2) + radius;
- const y2abs = yscale(y2);
- //const xstep = (x2abs - x1abs) / (this.nodeCounts + 1);
- //const ystep = (y2abs - y1abs) / (this.nodeCounts + 1);
- const xstep = (x2abs - x1abs) / 4;
- const ystep = (y2abs - y1abs) / 4;
- //const shift = (nodeIndex - (this.nodeCounts - 1) / 2);
- const shift = 1;
- let midxabs = ((x2abs + x1abs) / 2) - xstep * shift;
- let midyabs = ((y2abs + y1abs) / 2) - ystep * shift;
- const weight = weights[inputNodeIndex][nodeIndex];
- const lineColor = weight >= 0 ? "blue" : "red";
- svg.append("line")
- .attr("x1", x1abs)
- .attr("y1", y1abs)
- .attr("x2", x2abs)
- .attr("y2", y2abs)
- .attr("stroke", lineColor);
- // weights
- svg.append("circle")
- .attr("cx", midxabs)
- .attr("cy", midyabs)
- .attr("r", radiusWeight)
- .attr("stroke", "green")
- .attr("stroke-width", "1")
- .attr("fill", "white");
- svg.append("text")
- .attr("x", midxabs)
- .attr("y", midyabs)
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "1.3em")
- .attr("stroke", "blue")
- .attr("stroke-width", "1")
- .attr("fill", "green")
- .text(weight.toFixed(2));
- // circle around multi weight
- svg.append("circle")
- .attr("cx", midxabs + xscale(2))
- .attr("cy", midyabs - yscale(1.5))
- .attr("r", radiusWeight - 5)
- .attr("stroke", "blue")
- .attr("stroke-width", "1")
- .attr("fill", "white");
- const weightMulInput = weights[inputNodeIndex][nodeIndex] * output[inputNodeIndex];
- svg.append("text")
- .attr("x", midxabs + xscale(2))
- .attr("y", midyabs - yscale(1.5))
- .attr("fill", "white")
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "1em")
- .attr("stroke", "green")
- .attr("stroke-width", "1")
- .attr("fill", "green")
- .text(weightMulInput.toFixed(2));
- if (deltaWeights)
- {
- // circle around delta weight
- svg.append("circle")
- .attr("cx", midxabs)
- .attr("cy", midyabs + yscale(3.3))
- .attr("r", radiusWeight - 5)
- .attr("stroke", "red")
- .attr("stroke-width", "1")
- .attr("fill", "white");
- const deltaWeight = deltaWeights[inputNodeIndex][nodeIndex];
- svg.append("text")
- .attr("x", midxabs)
- .attr("y", midyabs + yscale(3.3))
- .attr("fill", "white")
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "1em")
- .attr("stroke", "red")
- .attr("stroke-width", "1")
- .attr("fill", "red")
- .text(deltaWeight.toFixed(2));
- // arrow
- let rotate = "";
- let color = "ff0000";
- if (deltaWeight > 0)
- {
- color = "00ff00";
- rotate = " rotate(180)";
- midxabs += xscale(1.0);
- midyabs += yscale(1.9);
- }
- svg.append("g")
- .attr("transform", "translate(" + (midxabs + xscale(1.0)) + ", " + (midyabs + yscale(2.2)) + ") scale(0.025, 0.025)" + rotate)
- .append("path")
- .attr("d", "M 680.00001,-27.858649 C 680.00001,-26.000219 330.00773,672.12987 329.42789,671.42805 C 327.4824,669.07328 -20.639627,-28.032809 -19.999107,-28.291219 C -19.580737,-28.459999 59.298143,5.3997002 155.28729,46.952552 L 329.81303,122.50321 L 504.05513,46.965887 C 599.8883,5.4203742 678.68037,-28.571419 679.14863,-28.571419 C 679.61689,-28.571419 680.00001,-28.250669 680.00001,-27.858649 z")
- .attr("style", "fill:#" +color + ";fill-opacity:1;stroke:none")
- }
- }));
- }
- // expecting result
- if (expectingVector) {
- const cx3 = xscale(xn + 10);
- const cy3 = yscale(yn);
- svg.append("circle")
- .attr("cx", cx3)
- .attr("cy", cy3)
- .attr("r", radius)
- .attr("stroke", "black")
- .attr("stroke-width", "3")
- .attr("fill", "white");
- svg.append("text")
- .attr("x", cx3)
- .attr("y", cy3)
- .attr("fill", "black")
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "3em")
- .text(expectingVector[i]);
- svg.append("line")
- .attr("x1", cx + radius)
- .attr("y1", cy)
- .attr("x2", cx3 - radius)
- .attr("y2", cy3)
- .attr("stroke", "blue");
- }
- });
- if (this.type == LayerTypes.OUTPUT) {
- const [xn, yn] = this.nodeCoords(this.layerIndex, this.nodeCounts + 1);
- const ex = xscale(xn);
- const ey = yscale(yn);
- svg.append("text")
- .attr("x", ex + 20)
- .attr("y", ey)
- .attr("fill", "black")
- .attr("text-anchor", "middle")
- .attr("dominant-baseline", "middle")
- .attr("font-size", "2em")
- .text("E = " + error[0].toFixed(4));
- }
- }
- }
- class Runner {
- // [input, hidden, ..., hidden, output]
- constructor(layerNumNodes) {
- this.layers = layerNumNodes.map((n, i) => {
- const l = new Layer(n);
- l.layerIndex = i;
- if (i == 0)
- {
- l.type = LayerTypes.INPUT;
- }
- else if (i == (layerNumNodes.length - 1))
- {
- l.type = LayerTypes.OUTPUT;
- }
- else
- {
- l.type = LayerTypes.HIDDEN;
- }
- return l;
- });
- }
- createPresentation() {
- const margin = { top: 10, right: 10, bottom: 10, left: 10 };
- const width = svgWidth - margin.left - margin.right;
- const height = svgHeight - margin.top - margin.bottom;
- const svgwidth = svgWidth + margin.left + margin.right;
- const svgheight = svgHeight + margin.top + margin.bottom;
- if (this.svg)
- {
- d3.select("#area").select("svg").remove();
- }
- // svg
- this.svg =
- d3.select("#area")
- .append("svg")
- .attr("width", svgwidth)
- .attr("height", svgheight)
- .attr("style", 'border: solid 1px green;')
- .append("g")
- .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
- // axises
- this.x = d3.scaleLinear().domain([0, 100]).range([0, width]);
- this.y = d3.scaleLinear().domain([0, 100]).range([0, height]);
- this.svg
- .append('g')
- .attr("transform", "translate(0," + height + ")")
- .call(d3.axisBottom(this.x));
- this.svg
- .append('g')
- .call(d3.axisLeft(this.y));
- }
- setInputVector(inputVector) {
- return this;
- }
- setExpectingOutput(expectingVector) {
- return this;
- }
- show() {
- this.createPresentation();
- this.layers.forEach(l => l.createGraph(this));
- }
- }
- const nn = new NNSimulation(2, 3)
- .setInput([
- [0, 0, 1],
- [0, 1, 1],
- [1, 0, 0],
- [1, 1, 0],
- [1, 0, 1],
- [1, 1, 1],])
- .setOutput(
- [0, 1, 0, 1, 1, 0])
- .init()
- .forward()
- //.train(10000)
- ;
- const runner = new Runner([3, 3, 3, 1]);
- runner.bind = () => {
- // input
- runner.layers[0].input = nn.X/*[inputData]*/;
- runner.layers[0].inputAfterActivation = nn.X/*[inputData]*/;
- // hidden layers
- for (let i = 1; i <= nn.num_hidden_layers; i++)
- {
- const layerIndex = i - 1;
- runner.layers[i].output = runner.layers[i - 1].inputAfterActivation;
- runner.layers[i].weights = nn.hidden_weights[layerIndex];
- if (nn.hidden_weights_change)
- {
- runner.layers[i].deltaWeights = nn.hidden_weights_change[layerIndex];
- }
- runner.layers[i].input = nn.hidden_layer_outputs_before_activation[layerIndex]/*[inputData]*/;
- runner.layers[i].inputAfterActivation = nn.hidden_layer_outputs[layerIndex]/*[inputData]*/;
- }
- const outputLayer = nn.num_hidden_layers + 1;
- // output
- runner.layers[outputLayer].output = runner.layers[nn.num_hidden_layers].inputAfterActivation;
- runner.layers[outputLayer].weights = nn.output_weights;
- if (nn.output_weights_change)
- {
- runner.layers[outputLayer].deltaWeights = nn.output_weights_change;
- }
- runner.layers[outputLayer].input = nn.output_layer_outputs_before_activation/*[inputData]*/;
- runner.layers[outputLayer].inputAfterActivation = nn.output_layer_outputs/*[inputData]*/;
- runner.layers[outputLayer].expecting = nn.Y/*[inputData]*/;
- if (nn.output_error)
- {
- runner.layers[outputLayer].error = nn.output_error/*[inputData]*/;
- }
- }
- runner.bind();
- runner.show();
- </script>
- </body>
- </html>
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement