Advertisement
Guest User

Untitled

a guest
Nov 19th, 2019
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.01 KB | None | 0 0
  1. import numpy as np
  2. import scipy as sp
  3. from numpy import random as npr
  4. import matplotlib as mpl
  5. import matplotlib.pyplot as plt
  6. import matplotlib.cm as cm
  7.  
  8. from sklearn.feature_extraction import image
  9. from sklearn.cluster import KMeans
  10. import sklearn.metrics as metrics
  11. import skimage.io
  12.  
  13. ## 1. Load the image
  14.  
  15. """
  16. def rgb_to_sat(img):
  17.   # img is an RGB image stored as given by imread()
  18. """
  19.  
  20. img = skimage.io.imread("road.jpg")
  21. img_lab = rgb_to_sat(img)
  22.  
  23. ## 2. Downscale and flatten
  24.  
  25. # get the L and b channel
  26. img_lb = img_lab[:, :, ::2]
  27. # downscale
  28. img_small = img_lb[::3,::3]
  29. # flatten
  30. img_flat = img_small.reshape([img_small.shape[0]*img_small.shape[1], 2])
  31.  
  32. ## 3. Determine the best clusters
  33.  
  34. values = []
  35. clusters = range(2, 6)
  36. for k in clusters:
  37.     # do something
  38.     est = KMeans( n_clusters=k )
  39.     est.fit( img_flat )
  40.     clusters = est.predict(img_flat)
  41.     values.append(metrics.silhouette_score(img_flat, clusters))
  42.  
  43. L = np.argsort(-np.array(values))
  44.  
  45. nclusters = L[0] + 2
  46. nclusters = int(nclusters)
  47.  
  48. ## 4. Fit cluster and plot segments
  49.  
  50. model = KMeans(n_clusters = nclusters)
  51. model.fit(img_flat)
  52. labels = model.labels_
  53. predictions = model.predict(img_flat)
  54.  
  55. img_large = img_lb.reshape([img_lb.shape[0] * img_lb.shape[1], 2])
  56. predictions_1  = model.predict(img_large)
  57.  
  58. show_segmentations(img, nclusters, predictions_1)
  59.  
  60. ## 5. Extract road markers
  61.  
  62. def extract_cluster(image, predictions, cluster_number):
  63.     """
  64.    image: Original RGB Image
  65.    predictions: A numpy array of predictions, as given by the KMeans model.
  66.    cluster_number: Which cluster number to extract.
  67.    """
  68.     pixel_labels = predictions.reshape(image.shape[:2])
  69.     img = np.zeros(image.shape)
  70.     color_index = np.where(pixel_labels == cluster_number)
  71.     img[color_index[0],color_index[1]] = image[np.where(pixel_labels == cluster_number)] / 255
  72.     return img
  73.  
  74. lanes = extract_cluster(img,predictions_1,2)
  75. height = int(lanes.shape[0] / 2)
  76. lanes[0:height] = 0
  77.  
  78. plt.figure()
  79. plt.imshow(lanes)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement