• API
• FAQ
• Tools
• Archive
SHARE
TWEET # Untitled a guest Nov 19th, 2019 80 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import numpy as np
2. import scipy as sp
3. from numpy import random as npr
4. import matplotlib as mpl
5. import matplotlib.pyplot as plt
6. import matplotlib.cm as cm
7.
8. from sklearn.feature_extraction import image
9. from sklearn.cluster import KMeans
10. import sklearn.metrics as metrics
11. import skimage.io
12.
## 1. Load the image

# `img` is an RGB image as returned by `skimage.io.imread()`.
# NOTE(review): neither `img` nor `rgb_to_sat` is defined in this file -- the
# only trace of `rgb_to_sat` was a function header left inside a stray string
# literal. Both must be provided (e.g. load `img` with
# `skimage.io.imread(path)` and define `rgb_to_sat`) before this line can run.
# Presumably `rgb_to_sat` returns an (H, W, 3) array in a Lab-like colour
# space, judging by the `img_lab` name -- confirm.
img_lab = rgb_to_sat(img)

## 2. Downscale and flatten

# Keep channels 0 and 2 (the L and b channels, assuming `img_lab` is Lab).
img_lb = img_lab[:, :, ::2]
# Downscale by keeping every 3rd row and every 3rd column.
img_small = img_lb[::3, ::3]
# Flatten to one row per pixel with 2 features each; -1 lets NumPy infer the
# pixel count H*W (multiplying the shape tuples together is a TypeError).
img_flat = img_small.reshape(-1, 2)

## 3. Determine the best clusters

# Fit KMeans for each candidate k and score each clustering by its mean
# silhouette coefficient (higher is better).
# NOTE(review): silhouette_score is quadratic in the number of samples;
# consider its `sample_size` argument if img_flat is large.
values = []
clusters = range(2, 6)
for k in clusters:
    est = KMeans(n_clusters=k)
    # Use a separate name so the candidate range `clusters` is not rebound.
    labels_k = est.fit_predict(img_flat)
    values.append(metrics.silhouette_score(img_flat, labels_k))

# Indices into `clusters`, ordered from best to worst silhouette score.
L = np.argsort(-np.array(values))

# Map the top-ranked index back to its k. Converting the whole `L + 2` array
# with int() fails whenever there is more than one candidate.
nclusters = int(clusters[L[0]])

## 4. Fit cluster and plot segments

# Refit with the chosen number of clusters on the downscaled pixels.
model = KMeans(n_clusters=nclusters)
model.fit(img_flat)
labels = model.labels_
predictions = model.predict(img_flat)

# Label every pixel of the full-resolution L/b image with the model trained
# on the downscaled one; -1 infers the pixel count H*W.
img_large = img_lb.reshape(-1, 2)
predictions_1 = model.predict(img_large)

# NOTE(review): show_segmentations is not defined in this file; presumably a
# plotting helper taking (image, n_clusters, per-pixel labels) -- confirm.
show_segmentations(img, nclusters, predictions_1)

## 5. Extract road markers

def extract_cluster(image, predictions, cluster_number):
    """Return a copy of *image* in which only one cluster's pixels are kept.

    Parameters
    ----------
    image : ndarray, shape (H, W) or (H, W, C)
        Original RGB image. Kept pixels are divided by 255, so 8-bit input
        is assumed.
    predictions : array_like with H*W elements
        Per-pixel cluster labels in row-major order, as given by the KMeans
        model's ``predict`` on the flattened image.
    cluster_number : int
        Which cluster label to extract.

    Returns
    -------
    ndarray of float64, same shape as *image*
        Pixels whose label equals *cluster_number* divided by 255; every
        other pixel is 0.
    """
    pixel_labels = np.asarray(predictions).reshape(image.shape[:2])
    img = np.zeros(image.shape)
    # A 2-D boolean mask selects whole (row, col) pixels across all channels.
    # Indexing with the np.where tuple in both positions would address
    # (row, row) / (col, col) pairs, i.e. the wrong pixels.
    mask = pixel_labels == cluster_number
    img[mask] = image[mask] / 255
    return img

# Cluster label treated as road markers (hard-coded).
# NOTE(review): KMeans label numbering is arbitrary between fits (no
# random_state is set), so this index may need re-checking after refitting.
lane_cluster = 2
lanes = extract_cluster(img, predictions_1, lane_cluster)

# Blank out the top half of the rows -- presumably road markers only appear in
# the lower half of the frame. Slicing needs an integer row count, hence //.
height = lanes.shape[0] // 2
lanes[0:height] = 0

plt.figure()
plt.imshow(lanes)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top