Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import java.util.Random;
/**
 * Trains a single threshold neuron (perceptron) to reproduce the logical OR of {@code n} inputs,
 * using an adaptive learning step.
 *
 * <p>The neuron computes {@code y = f(w·x - t)}, where {@code f} is the unit step. After every
 * sample, the weights and threshold are corrected with {@code w -= alpha * err * x} and
 * {@code t += alpha * err}, where {@code err = y - target}. The step is
 * {@code alpha = 1 / (1 + |x|^2)}, computed for the sample being learned.
 */
public class AdoptiveStepOR {

    /**
     * Builds every combination of {@code n} binary inputs, scaled so that a logical 1 is
     * represented by {@code top}.
     *
     * @param n   number of inputs, in {@code [0, 30]}
     * @param top value used for a logical 1 (a logical 0 is {@code 0.0})
     * @return a {@code 2^n x n} matrix in which column {@code j} of row {@code i} is
     *         {@code top} when bit {@code j} of {@code i} is set, and {@code 0.0} otherwise
     * @throws IllegalArgumentException if {@code n} is negative or too large for {@code 2^n} rows
     */
    static double[][] getTruthTable(int n, double top) {
        if (n < 0 || n > 30) {
            throw new IllegalArgumentException("n must be in [0, 30]: " + n);
        }
        // One row per input combination: 2^n rows.
        double[][] points = new double[1 << n][n];
        for (int i = 0; i < points.length; i++) {
            for (int j = 0; j < n; j++) {
                points[i][j] = ((i >> j) & 1) * top;
            }
        }
        return points;
    }

    /**
     * Creates {@code n} random initial weights, each drawn from {@code [0, 1)}.
     *
     * @param n number of weights
     * @return a new array of random weights
     */
    static double[] getRandomWeights(int n) {
        Random r = new Random();
        double[] w = new double[n];
        for (int i = 0; i < w.length; i++) {
            w[i] = r.nextDouble();
        }
        return w;
    }

    /**
     * Computes the neuron's net input {@code sum(x[i] * w[i]) - t}.
     *
     * @param x input vector; {@code w} must be at least as long as this
     * @param w weights
     * @param t threshold
     * @return the weighted sum minus the threshold
     */
    static double weightedSum(double[] x, double[] w, double t) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * w[i];
        }
        return sum - t;
    }

    /**
     * Unit-step activation function.
     *
     * @param sum net input
     * @return {@code 1.0} if {@code sum > 0}, otherwise {@code 0.0}
     */
    static double f(double sum) {
        return sum > 0 ? 1.0 : 0.0;
    }

    /**
     * Target output: the logical OR of the inputs.
     *
     * @param x input vector
     * @return {@code 1} if any component is positive, otherwise {@code 0}
     */
    static double sample(double[] x) {
        for (double d : x) {
            if (d > 0) {
                return 1;
            }
        }
        return 0;
    }

    /**
     * Computes the adaptive learning step {@code 1 / (1 + sum(x[i]^2))} for input {@code x}.
     * The leading {@code 1} accounts for the threshold, which acts as a weight on a constant input.
     *
     * @param currentSpeed unused; kept for signature compatibility
     * @param x            the sample about to be learned
     * @return the step to use when correcting on {@code x}
     */
    static double getNewSpeed(double currentSpeed, double[] x) {
        double newSpeed = 1;
        for (int i = 0; i < x.length; i++) {
            newSpeed += x[i] * x[i];
        }
        return 1 / newSpeed;
    }

    /**
     * Returns corrected weights {@code w[i] - diff * x[i]}; {@code w} is not modified.
     *
     * @param x    input vector, at least as long as {@code w}
     * @param w    current weights
     * @param diff learning step multiplied by the error
     * @return a new array of corrected weights
     */
    static double[] getNewWeights(double[] x, double[] w, double diff) {
        double[] newW = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            newW[i] = w[i] - diff * x[i];
        }
        return newW;
    }

    /**
     * Returns the corrected threshold. Because the threshold is subtracted in
     * {@link #weightedSum}, its correction has the opposite sign to the weights' correction.
     *
     * @param t    current threshold
     * @param diff learning step multiplied by the error
     * @return {@code t + diff}
     */
    static double getNewThreshold(double t, double diff) {
        return t + diff;
    }

    /**
     * Renders the model as {@code "sum = w1*x1 + ... + wn*xn - t"}.
     *
     * @param w weights
     * @param t threshold
     * @return a one-line human-readable description of the neuron
     */
    static String describe(double[] w, double t) {
        StringBuilder sb = new StringBuilder("sum =");
        for (int i = 0; i < w.length; i++) {
            if (i > 0) {
                sb.append(" +");
            }
            sb.append(String.format(" %6.3f*x%d", w[i], i + 1));
        }
        sb.append(String.format(" - %4.2f", t));
        return sb.toString();
    }

    /**
     * Trains a 6-input OR neuron until a full epoch produces no errors. Prints the progress for
     * every sample, then the learned model, then the output for the all-zero input.
     */
    public static void main(String[] args) {
        int n = 6;
        double top = 9;
        double learningSpeed = 1.0;
        double t = new Random().nextDouble() * top - 2;
        double[] w = getRandomWeights(n);
        double[][] points = getTruthTable(n, top);

        System.out.println(describe(w, t));
        for (int epoch = 1; ; epoch++) {
            boolean converged = true;
            for (double[] x : points) {
                double sum = weightedSum(x, w, t);
                double err = f(sum) - sample(x);
                // The adaptive step belongs to the sample being learned, so compute it from x
                // before applying the correction.
                learningSpeed = getNewSpeed(learningSpeed, x);
                double diff = learningSpeed * err;
                w = getNewWeights(x, w, diff);
                t = getNewThreshold(t, diff);
                System.out.printf("alpha = %6.4f, sum = %6.3f, error = %4.1f, t = %4.2f%n",
                        learningSpeed, sum, err, t);
                converged &= err == 0.0;
            }
            System.out.println("========= " + epoch);
            if (converged) {
                break;
            }
        }
        System.out.println(describe(w, t));
        // Evaluate the all-zero input against the learned threshold t (OR of zeros must be 0).
        System.out.println(f(weightedSum(points[0], w, t)));
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement