View difference between Paste ID: mfQFhqLW and wD6PrDHw
SHOW: | | - or go back to the newest paste.
1
public void trainState(double[] lastObservation, double[] hiddenEnergy, double[] lastOutput, int chosenAction, double target)
2
    {
3
        _generation++;
4
        double[] inputs = addBias(lastObservation);
5
        Tuple< double[], double[]> value = getValue(lastObservation);
6
        hiddenEnergy = value._y;
7
        lastOutput = value._x;
8
        double alpha = 0.1;
9
        double errorOutput = (1 - lastOutput[chosenAction] * lastOutput[chosenAction]) * (target - lastOutput[chosenAction]);
10
        double errorHidden;
11
        
12
        for (int j = 0; j < _hiddenSize; j++)
13
        {
14
            errorHidden = (1 - hiddenEnergy[j] * hiddenEnergy[j]) * errorOutput * _layerWeights[1][j][chosenAction];
15
16
            for (int i = 0; i < _inputSize; i++)
17
            {
18
                updateWeight(0, errorHidden * inputs[i], errorOutput * _layerWeights[1][j][chosenAction], i, j);
19
            }
20
        }
21
22
        for (int i = 0; i < _hiddenSize; i++)
23
        {
24
            updateWeight(1, errorOutput * hiddenEnergy[i], errorOutput, i, chosenAction);
25
        }
26
    }
27
28-
private void updateWeight(int layer, double gradient, double error, int hiddenNode, int outputNode)
28+
private void updateWeight(int layer, double gradient, double error, int leftNode, int rightNode)
29
    {
30-
        final int change = sign(gradient * _lastGradient[layer][hiddenNode][outputNode]);
30+
        final int change = sign(gradient * _lastGradient[layer][leftNode][rightNode]);
31
        double weightChange = 0;
32-
        double delta = _lastUpdateValue[layer][hiddenNode][outputNode];
32+
        double delta = _lastUpdateValue[layer][leftNode][rightNode];
33
34
        if (change > 0)
35
        {
36
            delta = Math.min(delta * POSITIVE_ETA, DELTA_MAX);
37
            weightChange = -sign(gradient) * delta;
38-
            _lastGradient[layer][hiddenNode][outputNode] = gradient;
38+
            _lastGradient[layer][leftNode][rightNode] = gradient;
39
        } else if (change < 0)
40
        {
41
            delta = Math.max(delta * NEGATIVE_ETA, DELTA_MIN);
42
43-
            if (error > _lastError[layer][hiddenNode][outputNode])
43+
            if (error > _lastError[layer][leftNode][rightNode])
44
            {
45-
                weightChange = -_lastWeightChanges[layer][hiddenNode][outputNode];
45+
                weightChange = -_lastWeightChanges[layer][leftNode][rightNode];
46
            }
47
48-
            _lastGradient[layer][hiddenNode][outputNode] = 0.;
48+
            _lastGradient[layer][leftNode][rightNode] = 0.;
49
        } else if (change == 0)
50
        {
51
            weightChange = -sign(gradient) * delta;
52-
            _lastGradient[layer][hiddenNode][outputNode] = gradient;
52+
            _lastGradient[layer][leftNode][rightNode] = gradient;
53
        }
54
55-
        _lastUpdateValue[layer][hiddenNode][outputNode] = delta;
55+
        _lastUpdateValue[layer][leftNode][rightNode] = delta;
56-
        _layerWeights[layer][hiddenNode][outputNode] += weightChange;
56+
        _layerWeights[layer][leftNode][rightNode] += weightChange;
57-
        _lastWeightChanges[layer][hiddenNode][outputNode] = weightChange;
57+
        _lastWeightChanges[layer][leftNode][rightNode] = weightChange;
58-
        _lastError[layer][hiddenNode][outputNode] = error;
58+
        _lastError[layer][leftNode][rightNode] = error;
59
    }