Advertisement
hishlishter

Laba5

Mar 12th, 2025
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 12.69 KB | Source Code | 0 0
  1. package org.example;
  2.  
  3. public class Main {
  4.     // Количество нейронов по слоям
  5.     static final int INPUT_NEURONS = 4;    // температура, влажность, скорость ветра, давление
  6.     static final int HIDDEN_NEURONS = 3;
  7.     static final int OUTPUT_NEURONS = 4;   // 4 класса погоды
  8.     static final double LEARN_RATE = 0.2;
  9.     static final int MAX_SAMPLES = 16;
  10.  
  11.     // Весовые коэффициенты (с учётом смещений)
  12.     double[][] wih = new double[INPUT_NEURONS + 1][HIDDEN_NEURONS]; // вход -> скрытый слой
  13.     double[][] who = new double[HIDDEN_NEURONS + 1][OUTPUT_NEURONS]; // скрытый -> выходной слой
  14.  
  15.     // Значения нейронов
  16.     double[] inputs = new double[INPUT_NEURONS];
  17.     double[] hidden = new double[HIDDEN_NEURONS];
  18.     double[] target = new double[OUTPUT_NEURONS];
  19.     double[] actual = new double[OUTPUT_NEURONS];
  20.  
  21.     // Ошибки нейронов
  22.     double[] erro = new double[OUTPUT_NEURONS];
  23.     double[] errh = new double[HIDDEN_NEURONS];
  24.  
  25.     // Массив с названиями погодных условий (индексы: 0 – "Солнечно", 1 – "Облачно", 2 – "Дождливо", 3 – "Шторм")
  26.     String[] conditions = {"Солнечно", "Облачно", "Дождливо", "Шторм"};
  27.  
  28.     // Класс, представляющий обучающий пример
  29.     static class Sample {
  30.         double temperature;  // температура (например, °C)
  31.         double humidity;     // влажность (в %)
  32.         double windSpeed;    // скорость ветра (м/с)
  33.         double pressure;     // давление (гПа)
  34.         double[] out;        // one-hot представление класса
  35.  
  36.         Sample(double temperature, double humidity, double windSpeed, double pressure, double[] out) {
  37.             this.temperature = temperature;
  38.             this.humidity = humidity;
  39.             this.windSpeed = windSpeed;
  40.             this.pressure = pressure;
  41.             this.out = out;
  42.         }
  43.     }
  44.  
  45.     // Обучающий набор (16 примеров: по 4 для каждого класса)
  46.     Sample[] samples = new Sample[] {
  47.             // "Солнечно": высокая температура, низкая влажность, слабый ветер, высокое давление
  48.             new Sample(30, 20, 5, 1020, new double[]{1.0, 0.0, 0.0, 0.0}),
  49.             new Sample(32, 18, 4, 1018, new double[]{1.0, 0.0, 0.0, 0.0}),
  50.             new Sample(29, 22, 6, 1022, new double[]{1.0, 0.0, 0.0, 0.0}),
  51.             new Sample(31, 19, 5, 1021, new double[]{1.0, 0.0, 0.0, 0.0}),
  52.  
  53.             // "Облачно": умеренная температура, средняя влажность, слабый ветер, среднее давление
  54.             new Sample(22, 50, 3, 1012, new double[]{0.0, 1.0, 0.0, 0.0}),
  55.             new Sample(23, 48, 4, 1013, new double[]{0.0, 1.0, 0.0, 0.0}),
  56.             new Sample(21, 52, 3, 1011, new double[]{0.0, 1.0, 0.0, 0.0}),
  57.             new Sample(22, 50, 2, 1012, new double[]{0.0, 1.0, 0.0, 0.0}),
  58.  
  59.             // "Дождливо": ниже температура, высокая влажность, умеренный ветер, низкое давление
  60.             new Sample(18, 80, 10, 1005, new double[]{0.0, 0.0, 1.0, 0.0}),
  61.             new Sample(17, 85, 9, 1004, new double[]{0.0, 0.0, 1.0, 0.0}),
  62.             new Sample(19, 78, 11, 1006, new double[]{0.0, 0.0, 1.0, 0.0}),
  63.             new Sample(18, 82, 10, 1005, new double[]{0.0, 0.0, 1.0, 0.0}),
  64.  
  65.             // "Шторм": умеренная или низкая температура, очень высокая влажность, сильный ветер, очень низкое давление
  66.             new Sample(16, 90, 20, 995, new double[]{0.0, 0.0, 0.0, 1.0}),
  67.             new Sample(15, 92, 22, 993, new double[]{0.0, 0.0, 0.0, 1.0}),
  68.             new Sample(16, 88, 19, 996, new double[]{0.0, 0.0, 0.0, 1.0}),
  69.             new Sample(15, 91, 21, 994, new double[]{0.0, 0.0, 0.0, 1.0})
  70.     };
  71.  
  72.     // Фиксированные минимальные и максимальные значения для каждого входного параметра
  73.     // (значения подобраны на основе обучающего набора)
  74.     static final double TEMP_MIN = 15;    // °C
  75.     static final double TEMP_MAX = 32;    // °C
  76.     static final double HUMIDITY_MIN = 18; // %
  77.     static final double HUMIDITY_MAX = 92; // %
  78.     static final double WIND_MIN = 2;      // м/с
  79.     static final double WIND_MAX = 22;     // м/с
  80.     static final double PRESSURE_MIN = 993;  // гПа
  81.     static final double PRESSURE_MAX = 1022; // гПа
  82.  
  83.     // Нормализация входных данных: перевод значений в диапазон [0,1]
  84.     void normalizeInputs() {
  85.         inputs[0] = (inputs[0] - TEMP_MIN) / (TEMP_MAX - TEMP_MIN);           // температура
  86.         inputs[1] = (inputs[1] - HUMIDITY_MIN) / (HUMIDITY_MAX - HUMIDITY_MIN); // влажность
  87.         inputs[2] = (inputs[2] - WIND_MIN) / (WIND_MAX - WIND_MIN);             // скорость ветра
  88.         inputs[3] = (inputs[3] - PRESSURE_MIN) / (PRESSURE_MAX - PRESSURE_MIN);  // давление
  89.     }
  90.  
  91.     // Инициализация весов случайными значениями в диапазоне [-0.5, 0.5]
  92.     void assignRandomWeights() {
  93.         for (int inp = 0; inp < INPUT_NEURONS + 1; inp++) {
  94.             for (int hid = 0; hid < HIDDEN_NEURONS; hid++) {
  95.                 wih[inp][hid] = Math.random() - 0.5;
  96.             }
  97.         }
  98.         for (int hid = 0; hid < HIDDEN_NEURONS + 1; hid++) {
  99.             for (int out = 0; out < OUTPUT_NEURONS; out++) {
  100.                 who[hid][out] = Math.random() - 0.5;
  101.             }
  102.         }
  103.     }
  104.  
  105.     // Функция активации (сигмоида) и её производная
  106.     double sigmoid(double val) {
  107.         return 1.0 / (1.0 + Math.exp(-val));
  108.     }
  109.  
  110.     double sigmoidDerivative(double val) {
  111.         return val * (1.0 - val);
  112.     }
  113.  
  114.     // Прямое распространение сигнала по сети
  115.     void feedForward() {
  116.         // Вычисление выхода скрытого слоя
  117.         for (int hid = 0; hid < HIDDEN_NEURONS; hid++) {
  118.             double sum = 0.0;
  119.             for (int inp = 0; inp < INPUT_NEURONS; inp++) {
  120.                 sum += inputs[inp] * wih[inp][hid];
  121.             }
  122.             // Добавление смещения
  123.             sum += wih[INPUT_NEURONS][hid];
  124.             hidden[hid] = sigmoid(sum);
  125.         }
  126.         // Вычисление выхода выходного слоя
  127.         for (int out = 0; out < OUTPUT_NEURONS; out++) {
  128.             double sum = 0.0;
  129.             for (int hid = 0; hid < HIDDEN_NEURONS; hid++) {
  130.                 sum += hidden[hid] * who[hid][out];
  131.             }
  132.             // Добавление смещения
  133.             sum += who[HIDDEN_NEURONS][out];
  134.             actual[out] = sigmoid(sum);
  135.         }
  136.     }
  137.  
  138.     // Алгоритм обратного распространения ошибки
  139.     void backPropagate() {
  140.         // Вычисление ошибки на выходном слое
  141.         for (int out = 0; out < OUTPUT_NEURONS; out++) {
  142.             erro[out] = (target[out] - actual[out]) * sigmoidDerivative(actual[out]);
  143.         }
  144.         // Вычисление ошибки на скрытом слое
  145.         for (int hid = 0; hid < HIDDEN_NEURONS; hid++) {
  146.             errh[hid] = 0.0;
  147.             for (int out = 0; out < OUTPUT_NEURONS; out++) {
  148.                 errh[hid] += erro[out] * who[hid][out];
  149.             }
  150.             errh[hid] *= sigmoidDerivative(hidden[hid]);
  151.         }
  152.         // Обновление весов выходного слоя
  153.         for (int out = 0; out < OUTPUT_NEURONS; out++) {
  154.             for (int hid = 0; hid < HIDDEN_NEURONS; hid++) {
  155.                 who[hid][out] += LEARN_RATE * erro[out] * hidden[hid];
  156.             }
  157.             // Обновление смещения
  158.             who[HIDDEN_NEURONS][out] += LEARN_RATE * erro[out];
  159.         }
  160.         // Обновление весов скрытого слоя
  161.         for (int hid = 0; hid < HIDDEN_NEURONS; hid++) {
  162.             for (int inp = 0; inp < INPUT_NEURONS; inp++) {
  163.                 wih[inp][hid] += LEARN_RATE * errh[hid] * inputs[inp];
  164.             }
  165.             // Обновление смещения
  166.             wih[INPUT_NEURONS][hid] += LEARN_RATE * errh[hid];
  167.         }
  168.     }
  169.  
  170.     // Функция выбора индекса выходного нейрона с максимальным значением
  171.     int action(double[] vector) {
  172.         int sel = 0;
  173.         double max = vector[0];
  174.         for (int i = 1; i < OUTPUT_NEURONS; i++) {
  175.             if (vector[i] > max) {
  176.                 max = vector[i];
  177.                 sel = i;
  178.             }
  179.         }
  180.         return sel;
  181.     }
  182.  
  183.     public static void main(String[] args) {
  184.         Main wc = new Main();
  185.         wc.assignRandomWeights();
  186.  
  187.         int sampleIndex = 0;
  188.         double err;
  189.         // Обучение сети (например, 10000 итераций)
  190.         for (int step = 0; step < 10000; step++) {
  191.             sampleIndex = (sampleIndex + 1) % MAX_SAMPLES;
  192.             Sample s = wc.samples[sampleIndex];
  193.             // Задание входного вектора: температура, влажность, скорость ветра, давление
  194.             wc.inputs[0] = s.temperature;
  195.             wc.inputs[1] = s.humidity;
  196.             wc.inputs[2] = s.windSpeed;
  197.             wc.inputs[3] = s.pressure;
  198.             // Нормализация входных данных
  199.             wc.normalizeInputs();
  200.             // Копирование целевого вектора
  201.             for (int i = 0; i < OUTPUT_NEURONS; i++) {
  202.                 wc.target[i] = s.out[i];
  203.             }
  204.             wc.feedForward();
  205.             err = 0.0;
  206.             for (int i = 0; i < OUTPUT_NEURONS; i++) {
  207.                 double diff = s.out[i] - wc.actual[i];
  208.                 err += diff * diff;
  209.             }
  210.             err = 0.5 * err;
  211.             if (step % 1000 == 0) {
  212.                 System.out.println("step = " + step + " mse = " + err);
  213.             }
  214.             wc.backPropagate();
  215.         }
  216.  
  217.         System.out.println();
  218.         int correct = 0;
  219.         // Проверка сети на обучающих примерах
  220.         for (int i = 0; i < MAX_SAMPLES; i++) {
  221.             Sample s = wc.samples[i];
  222.             wc.inputs[0] = s.temperature;
  223.             wc.inputs[1] = s.humidity;
  224.             wc.inputs[2] = s.windSpeed;
  225.             wc.inputs[3] = s.pressure;
  226.             wc.normalizeInputs();
  227.             for (int j = 0; j < OUTPUT_NEURONS; j++) {
  228.                 wc.target[j] = s.out[j];
  229.             }
  230.             wc.feedForward();
  231.             int predicted = wc.action(wc.actual);
  232.             int expected = wc.action(wc.target);
  233.             if (predicted != expected) {
  234.                 System.out.println("Input: " + s.temperature + " " + s.humidity + " "
  235.                         + s.windSpeed + " " + s.pressure +
  236.                         " predicted: " + wc.conditions[predicted] +
  237.                         " expected: " + wc.conditions[expected]);
  238.             } else {
  239.                 correct++;
  240.             }
  241.         }
  242.         System.out.println("Network is " + ((float) correct / MAX_SAMPLES * 100.0) + "% correct\n");
  243.  
  244.         // Дополнительное тестирование с новыми входными данными
  245.         double[][] testInputs = {
  246.                 {30, 20, 5, 1020},   // ожидается "Солнечно"
  247.                 {22, 50, 3, 1012},   // ожидается "Облачно"
  248.                 {18, 80, 10, 1005},  // ожидается "Дождливо"
  249.                 {15, 92, 22, 993}    // ожидается "Шторм"
  250.         };
  251.         for (double[] test : testInputs) {
  252.             // Задаём тестовый входной вектор
  253.             wc.inputs[0] = test[0];
  254.             wc.inputs[1] = test[1];
  255.             wc.inputs[2] = test[2];
  256.             wc.inputs[3] = test[3];
  257.             wc.normalizeInputs();
  258.             wc.feedForward();
  259.             int index = wc.action(wc.actual);
  260.             System.out.println("Input: [" + test[0] + ", " + test[1] + ", " + test[2] + ", " + test[3]
  261.                     + "] -> " + wc.conditions[index]);
  262.         }
  263.     }
  264. }
  265.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement