Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import skimage.color
import skimage.io
import sklearn.metrics as metrics
from numpy import random as npr
from sklearn.cluster import KMeans
from sklearn.feature_extraction import image
## 1. Load the image
def rgb_to_sat(img):
    """Convert an RGB image to the CIELAB colour space.

    The pasted script only kept this helper's signature inside a string
    literal, so the call below failed with ``NameError``.

    NOTE(review): the conversion target is an assumption. The caller stores
    the result as ``img_lab`` and then selects "the L and b channel", which
    matches CIELAB -- confirm against the original helper.

    Args:
        img: RGB image stored as given by ``imread()``.

    Returns:
        The converted image; the last axis holds the L, a and b channels.
    """
    return skimage.color.rgb2lab(img)


img = skimage.io.imread("road.jpg")
img_lab = rgb_to_sat(img)
## 2. Downscale and flatten
# Keep every second channel of the converted image; for a three-channel
# image these are channels 0 and 2 (the L and b channels).
img_lb = img_lab[:, :, ::2]
# Downscale by keeping every third pixel along both image axes.
img_small = img_lb[::3, ::3]
# Flatten to one row per pixel with the two kept channels as columns.
small_rows, small_cols = img_small.shape[:2]
img_flat = img_small.reshape(small_rows * small_cols, 2)
## 3. Determine the best clusters
# Candidate cluster counts; one KMeans model is fitted and scored per entry.
clusters = range(2, 6)
values = []
for k in clusters:
    est = KMeans(n_clusters=k)
    # Use a separate name for the labels so the candidate range above is not
    # overwritten while it is being iterated.
    cluster_labels = est.fit_predict(img_flat)
    values.append(metrics.silhouette_score(img_flat, cluster_labels))
# Look the winner up in the candidate range itself rather than adding a
# hard-coded offset, so changing the range cannot skew the result.
nclusters = int(clusters[int(np.argmax(values))])
## 4. Fit cluster and plot segments
# Fit the final model on the downscaled pixels.
model = KMeans(n_clusters=nclusters).fit(img_flat)
labels = model.labels_
predictions = model.predict(img_flat)
# Flatten the full-resolution two-channel image the same way as the
# downscaled one so the fitted model can label every original pixel.
full_rows, full_cols = img_lb.shape[:2]
img_large = img_lb.reshape(full_rows * full_cols, 2)
predictions_1 = model.predict(img_large)
# NOTE(review): show_segmentations is not defined in the visible source;
# presumably it is provided elsewhere -- confirm before running.
show_segmentations(img, nclusters, predictions_1)
## 5. Extract road markers
def extract_cluster(image, predictions, cluster_number):
    """Return an image that shows only the pixels of one cluster.

    Args:
        image: Original RGB image; its first two axes form the pixel grid.
        predictions: A numpy array of predictions, as given by the KMeans
            model, holding one label per pixel of ``image``.
        cluster_number: Which cluster number to extract.

    Returns:
        A float array with the same shape as ``image``. Pixels labelled
        ``cluster_number`` hold the original values divided by 255; every
        other pixel is zero.
    """
    pixel_labels = predictions.reshape(image.shape[:2])
    # Build the selection once as a boolean mask and reuse it for both the
    # read and the write instead of calling np.where twice.
    mask = pixel_labels == cluster_number
    extracted = np.zeros(image.shape)
    extracted[mask] = image[mask] / 255
    return extracted
# NOTE(review): cluster index 2 is hard-coded; presumably it was the cluster
# holding the lane markings in the author's run -- verify, since it does not
# depend on nclusters.
lanes = extract_cluster(img, predictions_1, 2)
# Zero out the top half of the extracted image before displaying it.
height = lanes.shape[0] // 2
lanes[:height] = 0
plt.figure()
plt.imshow(lanes)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement