Guest User

Untitled

a guest
Apr 1st, 2016
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.06 KB | None | 0 0
'''
Created on 24 March 2015

@author: Pierre.Parrend
'''
  5. import unittest
  6.  
  7. from fr.unistra.mathinfo.ai.labsession1.normalization import Normalizer
  8. from fr.unistra.mathinfo.ai.labsession3.kmeans import KMeanClusterer
  9.  
  10.  
  11. class Test(unittest.TestCase):
  12.  
  13.  
  14.    def setUp(self):
  15.        pass
  16.  
  17.  
  18.    def tearDown(self):
  19.        pass
  20.  
  21.    def getDatasetSize(self, datafile):
  22.  
  23.        norm = Normalizer()
  24.        iris_data_matrix = norm.load_csv(datafile)
  25.        return len(iris_data_matrix)
  26.  
  27.    def testKMeanForcedInitialisation(self):
  28.        print("** test KMean Initalisation **")
  29.        
  30.        k = 3
  31.        datafile = "../datasets/iris.csv"
  32.        
  33.        # perform initialization
  34.        kMeanClusterer = KMeanClusterer(k, datafile)
  35.        
  36.        # check the number of clusters
  37.        clusterNumber = kMeanClusterer.getClusterNumber()
  38.        self.assertTrue(clusterNumber == k, "actual cluster number: " + str(clusterNumber) + "; expected: " + str(k))
  39.        
  40.        # check the consistency of each cluster
  41.        centroids = set()
  42.        for i in range(clusterNumber):
  43.            currentCluster = kMeanClusterer.getCluster(i)
  44.            # check centroid format
  45.            centroid = currentCluster.getCentroid()
  46.            expectedObsDimensions = 5
  47.            self.assertTrue(len(centroid) == expectedObsDimensions,
  48.                            "centroid expected to contain " + str(expectedObsDimensions) +
  49.                            " data items, has actually " + str(len(centroid)))
  50.            # check observation format
  51.            observations = currentCluster.getObservations()
  52.            self.assertTrue(len(observations) == 0, "0 observation expected per cluster, has actually "
  53.                            + str(len(observations)))
  54.            
  55.            # check all centroid are different at initialisation
  56.            new_centroid = False
  57.            if tuple(centroid) not in centroids:
  58.                new_centroid = True
  59.                centroids.add(tuple(centroid))
  60.            self.assertTrue(new_centroid, "centroid are different: " + str(new_centroid))
  61.        
  62.  
  63.    def testKMeanAssignement(self):
  64.        print("** test KMean assignement **")
  65.        
  66.        # perform initialization
  67.        k = 3
  68.        datafile = "../datasets/iris.csv"
  69.        kMeanClusterer = KMeanClusterer(k, datafile)
  70.        datasetSize = self.getDatasetSize(datafile)
  71.        
  72.        kMeanClusterer.assignement()
  73.        
  74.        obsNumber = 0
  75.        clusterNumber = kMeanClusterer.getClusterNumber()
  76.        for i in range(clusterNumber):
  77.            currentCluster = kMeanClusterer.getCluster(i)
  78.            observations = currentCluster.getObservations()
  79.            for obs in observations:
  80.                obsNumber += 1
  81.  
  82.        # check that all observations are assigned
  83.        self.assertTrue(datasetSize == obsNumber, "size of dataset: " + str(datasetSize)
  84.                        + "; current number of observations in cluster:" + str(obsNumber))
  85.        
  86.        # check that the observations are assigned to the nearest centroid
  87.        centroids = []
  88.        for i in range(clusterNumber):
  89.            currentCluster = kMeanClusterer.getCluster(i)
  90.            centroids.append(currentCluster.getCentroid())
  91.        
  92.        for i in range(clusterNumber):
  93.            currentCluster = kMeanClusterer.getCluster(i)
  94.            for obs in currentCluster.getObservations():
  95.                distance_to_centroid = kMeanClusterer.computeDistance(obs, currentCluster.getCentroid())
  96.                for j in range(len(centroids)):
  97.                    if i != j:
  98.                        dst = kMeanClusterer.computeDistance(obs, kMeanClusterer.getCluster(j).getCentroid())
  99.                        self.assertTrue(distance_to_centroid <= dst,
  100.                                        "distance to centroid of own cluster " + str(i) + ":"
  101.                                        + str(distance_to_centroid) + "; distance to centroid of cluster" + str(j) + ": " + str(dst))
  102.        
  103.  
  104. if __name__ == "__main__":
  105.    unittest.main()
Advertisement
Add Comment
Please, Sign In to add comment