SHARE
TWEET

Untitled

a guest Nov 19th, 2019 80 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import scipy as sp
  3. from numpy import random as npr
  4. import matplotlib as mpl
  5. import matplotlib.pyplot as plt
  6. import matplotlib.cm as cm
  7.  
  8. from sklearn.feature_extraction import image
  9. from sklearn.cluster import KMeans
  10. import sklearn.metrics as metrics
  11. import skimage.io
  12.  
  13. ## 1. Load the image
  14.  
  15. """
  16. def rgb_to_sat(img):
  17.   # img is an RGB image stored as given by imread()
  18. """
  19.  
  20. img = skimage.io.imread("road.jpg")
  21. img_lab = rgb_to_sat(img)
  22.  
  23. ## 2. Downscale and flatten
  24.  
  25. # get the L and b channel
  26. img_lb = img_lab[:, :, ::2]
  27. # downscale
  28. img_small = img_lb[::3,::3]
  29. # flatten
  30. img_flat = img_small.reshape([img_small.shape[0]*img_small.shape[1], 2])
  31.  
  32. ## 3. Determine the best clusters
  33.  
  34. values = []
  35. clusters = range(2, 6)
  36. for k in clusters:
  37.     # do something
  38.     est = KMeans( n_clusters=k )
  39.     est.fit( img_flat )
  40.     clusters = est.predict(img_flat)
  41.     values.append(metrics.silhouette_score(img_flat, clusters))
  42.  
  43. L = np.argsort(-np.array(values))
  44.  
  45. nclusters = L[0] + 2
  46. nclusters = int(nclusters)
  47.  
  48. ## 4. Fit cluster and plot segments
  49.  
  50. model = KMeans(n_clusters = nclusters)
  51. model.fit(img_flat)
  52. labels = model.labels_
  53. predictions = model.predict(img_flat)
  54.  
  55. img_large = img_lb.reshape([img_lb.shape[0] * img_lb.shape[1], 2])
  56. predictions_1  = model.predict(img_large)
  57.  
  58. show_segmentations(img, nclusters, predictions_1)
  59.  
  60. ## 5. Extract road markers
  61.  
  62. def extract_cluster(image, predictions, cluster_number):
  63.     """
  64.    image: Original RGB Image
  65.    predictions: A numpy array of predictions, as given by the KMeans model.
  66.    cluster_number: Which cluster number to extract.
  67.    """
  68.     pixel_labels = predictions.reshape(image.shape[:2])
  69.     img = np.zeros(image.shape)
  70.     color_index = np.where(pixel_labels == cluster_number)
  71.     img[color_index[0],color_index[1]] = image[np.where(pixel_labels == cluster_number)] / 255
  72.     return img
  73.  
  74. lanes = extract_cluster(img,predictions_1,2)
  75. height = int(lanes.shape[0] / 2)
  76. lanes[0:height] = 0
  77.  
  78. plt.figure()
  79. plt.imshow(lanes)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top