SHOW:
|
|
- or go back to the newest paste.
1 | public void trainState(double[] lastObservation, double[] hiddenEnergy, double[] lastOutput, int chosenAction, double target) | |
2 | { | |
3 | _generation++; | |
4 | double[] inputs = addBias(lastObservation); | |
5 | Tuple< double[], double[]> value = getValue(lastObservation); | |
6 | hiddenEnergy = value._y; | |
7 | lastOutput = value._x; | |
8 | double alpha = 0.1; | |
9 | double errorOutput = (1 - lastOutput[chosenAction] * lastOutput[chosenAction]) * (target - lastOutput[chosenAction]); | |
10 | double errorHidden; | |
11 | ||
12 | for (int j = 0; j < _hiddenSize; j++) | |
13 | { | |
14 | errorHidden = (1 - hiddenEnergy[j] * hiddenEnergy[j]) * errorOutput * _layerWeights[1][j][chosenAction]; | |
15 | ||
16 | for (int i = 0; i < _inputSize; i++) | |
17 | { | |
18 | updateWeight(0, errorHidden * inputs[i], errorOutput * _layerWeights[1][j][chosenAction], i, j); | |
19 | } | |
20 | } | |
21 | ||
22 | for (int i = 0; i < _hiddenSize; i++) | |
23 | { | |
24 | updateWeight(1, errorOutput * hiddenEnergy[i], errorOutput, i, chosenAction); | |
25 | } | |
26 | } | |
27 | ||
28 | - | private void updateWeight(int layer, double gradient, double error, int hiddenNode, int outputNode) |
28 | + | private void updateWeight(int layer, double gradient, double error, int leftNode, int rightNode) |
29 | { | |
30 | - | final int change = sign(gradient * _lastGradient[layer][hiddenNode][outputNode]); |
30 | + | final int change = sign(gradient * _lastGradient[layer][leftNode][rightNode]); |
31 | double weightChange = 0; | |
32 | - | double delta = _lastUpdateValue[layer][hiddenNode][outputNode]; |
32 | + | double delta = _lastUpdateValue[layer][leftNode][rightNode]; |
33 | ||
34 | if (change > 0) | |
35 | { | |
36 | delta = Math.min(delta * POSITIVE_ETA, DELTA_MAX); | |
37 | weightChange = -sign(gradient) * delta; | |
38 | - | _lastGradient[layer][hiddenNode][outputNode] = gradient; |
38 | + | _lastGradient[layer][leftNode][rightNode] = gradient; |
39 | } else if (change < 0) | |
40 | { | |
41 | delta = Math.max(delta * NEGATIVE_ETA, DELTA_MIN); | |
42 | ||
43 | - | if (error > _lastError[layer][hiddenNode][outputNode]) |
43 | + | if (error > _lastError[layer][leftNode][rightNode]) |
44 | { | |
45 | - | weightChange = -_lastWeightChanges[layer][hiddenNode][outputNode]; |
45 | + | weightChange = -_lastWeightChanges[layer][leftNode][rightNode]; |
46 | } | |
47 | ||
48 | - | _lastGradient[layer][hiddenNode][outputNode] = 0.; |
48 | + | _lastGradient[layer][leftNode][rightNode] = 0.; |
49 | } else if (change == 0) | |
50 | { | |
51 | weightChange = -sign(gradient) * delta; | |
52 | - | _lastGradient[layer][hiddenNode][outputNode] = gradient; |
52 | + | _lastGradient[layer][leftNode][rightNode] = gradient; |
53 | } | |
54 | ||
55 | - | _lastUpdateValue[layer][hiddenNode][outputNode] = delta; |
55 | + | _lastUpdateValue[layer][leftNode][rightNode] = delta; |
56 | - | _layerWeights[layer][hiddenNode][outputNode] += weightChange; |
56 | + | _layerWeights[layer][leftNode][rightNode] += weightChange; |
57 | - | _lastWeightChanges[layer][hiddenNode][outputNode] = weightChange; |
57 | + | _lastWeightChanges[layer][leftNode][rightNode] = weightChange; |
58 | - | _lastError[layer][hiddenNode][outputNode] = error; |
58 | + | _lastError[layer][leftNode][rightNode] = error; |
59 | } |