Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
//
//  TrainAndTest.c
//  MLCoursework
//
//  This is a fairly inefficient implementation that does not use any dynamic memory allocation
//  because that would not be safe on the DEWIS marking system
//
//  Created by Jim Smith on 06/02/2017.
//  Copyright © 2017 Jim Smith. All rights reserved.
//
#include "TrainAndTest.h"

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
- //declare this array as static but make it available to any function in this file
- //in case we want to store the training examples and use them later
- static double myModel[NUM_TRAINING_SAMPLES][NUM_FEATURES];
- //even if each item in the training set is from a diffferent class we know how many there are
- static char myModelLabels[NUM_TRAINING_SAMPLES];
- static int trainingSetSize = 0;
// Exchanges the two doubles pointed to by a and b (helper for the bubble sort).
// Both pointers must be valid; passing the same address twice leaves the value unchanged.
void swap(double *a, double *b) {
    double placeholder = *a;
    *a = *b;
    *b = placeholder;
}
// Bubble sort: arranges the first n elements of array in ascending order, in place.
// Does nothing when n is 0 or 1 (or negative), because the outer loop never runs.
void bubbles(double array[], int n) {
    for (int pass = 0; pass < n - 1; pass++) {
        int swapped = 0;
        // The last `pass` elements are already in their final positions.
        for (int j = 0; j < n - pass - 1; j++) {
            if (array[j] > array[j + 1]) {
                swap(&array[j], &array[j + 1]);
                swapped = 1;
            }
        }
        // A pass with no exchanges means the array is already sorted.
        if (!swapped) {
            break;
        }
    }
}//end of bubble sort
- int train(double **trainingSamples, char *trainingLabels, int numSamples, int numFeatures) {
- int returnval = 1;
- int sample, feature;
- //clean the model because C leaves whatever is in the memory |JIM|
- for (sample = 0; sample < NUM_TRAINING_SAMPLES; sample++) {
- for (feature = 0; feature<NUM_FEATURES; feature++) {
- myModel[sample][feature] = 0.0;
- }
- }
- //sanity checking |JIM|
- if (numFeatures > NUM_FEATURES || numSamples > NUM_TRAINING_SAMPLES) {
- fprintf(stdout, "error: called train with data set larger than spaced allocated to store it");
- returnval = 0;
- }
- if (returnval == 1) {
- //store the labels and the feature values
- trainingSetSize = numSamples;
- int index, feature;
- for (index = 0; index < numSamples; index++) {
- myModelLabels[index] = trainingLabels[index];
- for (feature = 0; feature < numFeatures; feature++) {
- myModel[index][feature] = trainingSamples[index][feature];
- }
- }
- fprintf(stdout, "data stored locally \n");
- }//end else
- return returnval;
- }
- char predictLabel(double *testSample, int numFeatures) {
- double duplicateArray[NUM_TRAINING_SAMPLES]; //Array to hold neighbourDistances before sorting
- double neighbourDistances[NUM_TRAINING_SAMPLES]; //Array to hold distances to neighbours
- int catA = 0, catB = 0, catC = 0; // Variables for classifying and holding the different categories of data
- int k = 9; //Number of neighbours checked
- int prediction[9];
- // square root of the sum of the squared differences between the two arrays of numbers
- for (int i = 0; i < NUM_TRAINING_SAMPLES; i++)
- neighbourDistances[i] =
- sqrt((myModel[i][0] - testSample[0]) * (myModel[i][0] - testSample[0]) +
- (myModel[i][1] - testSample[1]) * (myModel[i][1] - testSample[1]) +
- (myModel[i][2] - testSample[2]) * (myModel[i][2] - testSample[2]) +
- (myModel[i][3] - testSample[3]) * (myModel[i][3] - testSample[3]));
- //puts original array into new duplicateArray
- for (int i = 0; i < NUM_TRAINING_SAMPLES; i++)
- duplicateArray[i] = neighbourDistances[i];
- //sorts original array
- bubbles(neighbourDistances, NUM_TRAINING_SAMPLES);
- //compares original array with the sorted array
- for (int i = 0; i < k; i++) {
- for (int j = 0; j < NUM_TRAINING_SAMPLES; j++) {
- if (neighbourDistances[i] == duplicateArray[j]) {
- prediction[i] = j;
- }
- }
- }
- //prints sorted neighbours array
- printf("Sorted array: \n");
- for (int i = 0; i < k; i++) {
- printf("%lf\n", neighbourDistances[i]);
- }
- for (int i = 0; i < k; i++)
- {
- if (myModelLabels[prediction[i]] == myModelLabels[0])
- catA++;
- else if (myModelLabels[prediction[i]] > myModelLabels[0] && myModelLabels[prediction[i]] <= myModelLabels[49])
- catB++;
- else if (myModelLabels[prediction[i]] >= myModelLabels[99])
- catC++;
- }
- //Check likelihood of category and return most prominent one
- if ((catA > catB) && (catA > catC)) return myModelLabels[0];
- else if ((catB > catA) && (catB > catC)) return myModelLabels[49];
- else return myModelLabels[99];
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement