Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- public void trainState(double[] lastObservation, double[] hiddenEnergy, double[] lastOutput, int chosenAction, double target)
- {
- _generation++;
- double[] inputs = addBias(lastObservation);
- Tuple< double[], double[]> value = getValue(lastObservation);
- hiddenEnergy = value._y;
- lastOutput = value._x;
- double alpha = 0.1;
- double errorOutput = (1 - lastOutput[chosenAction] * lastOutput[chosenAction]) * (target - lastOutput[chosenAction]);
- double errorHidden;
- for (int j = 0; j < _hiddenSize; j++)
- {
- errorHidden = (1 - hiddenEnergy[j] * hiddenEnergy[j]) * errorOutput * _layerWeights[1][j][chosenAction];
- for (int i = 0; i < _inputSize; i++)
- {
- updateWeight(0, errorHidden * inputs[i], errorOutput * _layerWeights[1][j][chosenAction], i, j);
- }
- }
- for (int i = 0; i < _hiddenSize; i++)
- {
- updateWeight(1, errorOutput * hiddenEnergy[i], errorOutput, i, chosenAction);
- }
- }
- private void updateWeight(int layer, double gradient, double error, int leftNode, int rightNode)
- {
- final int change = sign(gradient * _lastGradient[layer][leftNode][rightNode]);
- double weightChange = 0;
- double delta = _lastUpdateValue[layer][leftNode][rightNode];
- if (change > 0)
- {
- delta = Math.min(delta * POSITIVE_ETA, DELTA_MAX);
- weightChange = -sign(gradient) * delta;
- _lastGradient[layer][leftNode][rightNode] = gradient;
- } else if (change < 0)
- {
- delta = Math.max(delta * NEGATIVE_ETA, DELTA_MIN);
- if (error > _lastError[layer][leftNode][rightNode])
- {
- weightChange = -_lastWeightChanges[layer][leftNode][rightNode];
- }
- _lastGradient[layer][leftNode][rightNode] = 0.;
- } else if (change == 0)
- {
- weightChange = -sign(gradient) * delta;
- _lastGradient[layer][leftNode][rightNode] = gradient;
- }
- _lastUpdateValue[layer][leftNode][rightNode] = delta;
- _layerWeights[layer][leftNode][rightNode] += weightChange;
- _lastWeightChanges[layer][leftNode][rightNode] = weightChange;
- _lastError[layer][leftNode][rightNode] = error;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement