Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import requests
- import shutil
- import random
- import time
- import statistics
- import os
- from PIL import Image
- from PIL import ImageStat
- from PIL import ImageEnhance
- from os.path import isfile, join
- from selenium import webdriver
- from selenium.webdriver.common.keys import Keys
- from selenium.webdriver.common.by import By
- from selenium.webdriver.support.ui import WebDriverWait
- from selenium.webdriver.support import expected_conditions as EC
# Requires selenium, requests, statistics and pillow.

# --- Image-filtering constants ---
NUM_IMAGES = 7700               # number of accepted tiles to collect
NUM_COLORS = 8                  # palette size used when quantizing a tile
IMAGE_SIZE = 128, 128           # (width, height) of a saved tile
MAX_STANDARD_DEVIATION = 10000  # a tile's color spread must exceed this to be kept
MIN_BLUE = 0.452                # tiles whose blue share reaches this are rejected as ocean

# --- Request constants ---
BASE_URL = 'http://maps.google.com/maps/api/staticmap'
REQUEST_DELAY = 0               # seconds to sleep before each request
SAVE_LOCATION = 'Images'        # directory that accepted tiles are written to


def get_zoom():
    """Return a random Static Maps zoom level, as a string between '8' and '11'."""
    zoom_levels = ['8', '9', '10', '11']
    return random.choice(zoom_levels)


# Router settings used by initiate_driver() to reset the IP address through
# the router's web UI; uncomment and fill in to enable.
# ROUTER_URL = 'http://10.0.0.138/gateway.lp'
# ROUTER_USERNAME = 'admin'
# ROUTER_PASSWORD = 'password'
# driver = webdriver.Chrome()
def mean_color(colors, type='histogram'):
    """Return the per-channel mean color.

    Args:
        colors: When ``type`` is ``'histogram'``, a sequence of
            ``(count, color)`` pairs as returned by ``Image.getcolors()``.
            When ``type`` is ``'samples'``, a flat sequence of color tuples.
        type: ``'histogram'`` (default) or ``'samples'``. The name shadows
            the builtin but is kept so keyword callers keep working.

    Returns:
        A list with one mean per channel, e.g. ``[r, g, b]``. It is a list
        rather than a lazy ``map`` so callers can index it on Python 3 too.

    Raises:
        KeyError: if ``type`` is not one of the two supported values.
    """
    if type == 'histogram':
        # Expand each (count, color) pair into `count` copies of `color`.
        samples = [color for count, color in colors for _ in range(count)]
    elif type == 'samples':
        samples = colors
    else:
        raise KeyError('unknown colors type: {!r}'.format(type))
    # zip(*samples) regroups the samples by channel: (r0, r1, ...), (g0, ...), ...
    return [statistics.mean(channel) for channel in zip(*samples)]
def color_standard_deviation(colors):
    """Return how widely an image's colors spread around its mean color.

    Args:
        colors: A histogram of ``(count, color)`` pairs as returned by
            ``Image.getcolors()``.

    Returns:
        The square root of the *total* squared RGB distance of every pixel
        from the mean color. The sum is not divided by the pixel count, so
        the value grows with image size. eval_image() compares this exact
        un-normalized value against MAX_STANDARD_DEVIATION, so the formula
        is intentionally kept as-is.
    """
    # Materialize once: the histogram is iterated twice below.
    colors = list(colors)
    # list(...) so the mean can be indexed even if mean_color hands back a
    # lazy iterator (as a Python 3 `map` would).
    mean = list(mean_color(colors))
    total = 0
    for count, color in colors:
        # Squared distance across the three channels, weighted by the number
        # of pixels with this color. This equals summing it once per pixel.
        total += count * sum((color[i] - mean[i]) ** 2 for i in range(3))
    return total ** 0.5
def color_percentage(color):
    """Return the share of the total intensity held by each channel of ``color``.

    For ``(r, g, b)`` this is ``(r/s, g/s, b/s)`` with ``s = r + g + b``.
    When the channels sum to zero (e.g. pure black), no channel dominates,
    so an even ``(1/3, 1/3, 1/3)`` split is returned.

    ``color`` may be any iterable, including a one-shot iterator. The result
    is always a tuple of floats. Float division is forced so that integer
    channels do not truncate to 0 under Python 2.
    """
    # Materialize first: summing a one-shot iterator would otherwise exhaust it.
    channels = tuple(color)
    total = sum(channels)
    if total == 0:
        # All channels are zero, so treat them as equal.
        return 1.0 / 3, 1.0 / 3, 1.0 / 3
    return tuple(float(c) / total for c in channels)
def random_location():
    """Pick a random latitude/longitude pair.

    Returns:
        ``(lat, lon)`` as strings in fixed-point notation with 7 decimal
        places. Latitude is drawn uniformly from [-90, 90] and longitude from
        [-180, 180]. Sampling is uniform in degrees, not in surface area, so
        high latitudes are over-represented per unit of area.
    """
    # Use the full symmetric ranges. Upper bounds of 89 / 179 would silently
    # skip the band north of 89N and east of 179E while keeping the matching
    # southern and western bands.
    lat = random.uniform(-90, 90)
    lon = random.uniform(-180, 180)
    # Fixed-point formatting. str(round(x, 7)) can produce scientific notation
    # (e.g. '5e-07') for values near 0, which is unlikely to be parsed as a
    # coordinate by the map API or to make a sensible file name.
    return '{:.7f}'.format(lat), '{:.7f}'.format(lon)
def format_image(img):
    """Prepare a downloaded tile for analysis.

    The image is quantized to an adaptive palette of NUM_COLORS colors. It is
    then cropped to the top-left IMAGE_SIZE region and returned in RGB mode.
    """
    width, height = IMAGE_SIZE
    quantized = img.convert('P', palette=Image.ADAPTIVE, colors=NUM_COLORS)
    cropped = quantized.crop((0, 0, width, height))
    return cropped.convert('RGB')
def eval_image(img):
    """Decide whether a formatted tile belongs in the dataset.

    A tile is kept when two conditions hold. Its colors must be varied enough,
    meaning color_standard_deviation() exceeds MAX_STANDARD_DEVIATION. Its
    mean color's blue share must also be below MIN_BLUE, which filters out
    tiles that are mostly ocean.

    Args:
        img: A PIL image, presumably already passed through format_image().

    Returns:
        True if the tile should be kept, otherwise False. Also returns False,
        instead of crashing, when ``img.getcolors()`` returns None because the
        image has more distinct colors than getcolors' default limit.
    """
    # Build the histogram once and reuse it for both checks.
    histogram = img.getcolors()
    if histogram is None:
        return False
    if color_standard_deviation(histogram) <= MAX_STANDARD_DEVIATION:
        return False
    _, _, blue_share = color_percentage(mean_color(histogram))
    # A high blue share means the tile is probably open ocean.
    return blue_share < MIN_BLUE
def handle_api_limit():
    """Back off after the map API has started refusing requests.

    main() calls this on an HTTP 403. Sending too many requests to Google's
    API gets the client blocked, so this just sleeps for 80000 seconds
    (about 22 hours) before returning to the caller.
    """
    print('[!] API limit hit')
    backoff_seconds = 80000
    time.sleep(backoff_seconds)
    # Disabled alternative: instead of waiting, drive the router's web UI via
    # Selenium to disconnect/reconnect and obtain a new IP address.
    # WebDriverWait(driver, 40).until(EC.presence_of_element_located((By.ID, 'Disconnect')))
    # driver.find_element_by_id('Disconnect').click()
    # WebDriverWait(driver, 20).until(EC.presence_of_element_located((By.ID, 'Connect')))
    # driver.find_element_by_id('Connect').click()
    # WebDriverWait(driver, 40).until(EC.presence_of_element_located((By.ID, 'Disconnect')))
def post_format_images():
    """Normalize the brightness of every image in the dataset.

    Reads every file in SAVE_LOCATION and computes the average grayscale
    brightness across all of them. Each image's brightness is then scaled to
    match that average, the image is resized to IMAGE_SIZE, and the result is
    written to the directory ``Normalized_<SAVE_LOCATION>`` under the same
    file name (the directory is created if needed). If SAVE_LOCATION holds
    no files, a message is printed and nothing is written.
    """
    out_dir = 'Normalized_{}'.format(SAVE_LOCATION)
    files = [f for f in os.listdir(SAVE_LOCATION)
             if os.path.isfile(os.path.join(SAVE_LOCATION, f))]
    if not files:
        print('[!] No images found in {}'.format(SAVE_LOCATION))
        return
    # First pass: measure each image once and reuse the values below.
    levels = [_mean_brightness(os.path.join(SAVE_LOCATION, f)) for f in files]
    avrg = statistics.mean(levels)
    print('[*] Normalizing images for average brightness: {}'.format(avrg))
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    # Second pass: rescale each image towards the average brightness.
    for f, level in zip(files, levels):
        # A completely black image (level 0) cannot be brightened by scaling,
        # and dividing by it would fail, so leave its brightness unchanged.
        intensity = avrg / level if level else 1.0
        with Image.open(os.path.join(SAVE_LOCATION, f)) as img:
            normalized = ImageEnhance.Brightness(img).enhance(intensity)
        normalized = normalized.resize(IMAGE_SIZE)
        normalized.save(os.path.join(out_dir, f))


def _mean_brightness(path):
    """Return the mean grayscale ('L' mode) brightness of the image at ``path``."""
    with Image.open(path) as img:
        return ImageStat.Stat(img.convert('L')).mean[0]
def initiate_driver(d):
    """Log in to the router's web UI and open its 'Internet Access' page.

    This is part of automating IP-address resets. ROUTER_URL,
    ROUTER_USERNAME and ROUTER_PASSWORD must be defined at module level
    before calling it.

    Args:
        d: A Selenium WebDriver instance.

    Returns:
        True once the 'Internet Access' link has been clicked.
    """
    d.get(ROUTER_URL)
    for element_id, value in (('srp_username', ROUTER_USERNAME),
                              ('srp_password', ROUTER_PASSWORD)):
        box = d.find_element(By.ID, element_id)
        box.clear()
        box.send_keys(value)
    d.find_element(By.ID, 'sign-me-in').click()
    # Wait on the driver that was passed in. The module-level `driver` may
    # not exist, or may be a different browser session.
    WebDriverWait(d, 40).until(
        EC.element_to_be_clickable((By.XPATH, "//u[text()='Internet Access']"))
    ).click()
    # Fixed pause, presumably to let the router page settle before the
    # caller continues — TODO confirm the required delay.
    time.sleep(6)
    return True
def main():
    """Collect NUM_IMAGES suitable satellite tiles, then normalize them.

    Each iteration downloads a satellite tile of a random location and runs
    it through format_image(). If eval_image() accepts it, the tile is saved
    as ``<SAVE_LOCATION>/<lat>_<lon>.png``. Once enough tiles are collected,
    post_format_images() normalizes brightness across the dataset.
    """
    # initiate_driver(driver)
    if not os.path.isdir(SAVE_LOCATION):
        os.makedirs(SAVE_LOCATION)
    num_images = 0
    # Keep going until NUM_IMAGES tiles have been added to the dataset.
    while num_images < NUM_IMAGES:
        time.sleep(REQUEST_DELAY)
        lat, lon = random_location()
        img = _download_tile(lat, lon)
        if img is None:
            continue
        img = format_image(img)
        if not eval_image(img):
            # Tile is not suitable for the dataset.
            continue
        img.save(os.path.join(SAVE_LOCATION, '{}_{}.png'.format(lat, lon)))
        num_images += 1
    # Once all tiles are collected, normalize their brightness.
    post_format_images()


def _download_tile(lat, lon):
    """Fetch one satellite tile centred on (lat, lon).

    Returns the tile as an RGB PIL image. Returns None when the request
    fails, the status is not 200, or the body is not a readable image. An
    HTTP 403, treated as the API limit being hit, calls handle_api_limit()
    before returning None.
    """
    payload = {
        'scale': '1',
        # requests URL-encodes params itself, so pass a literal comma. A
        # pre-encoded '%2C' would be sent double-encoded as '%252C'.
        'center': '{},{}'.format(lat, lon),
        'zoom': get_zoom(),
        'maptype': 'satellite',
        'sensor': 'false',
        # Request 10% more than IMAGE_SIZE; format_image() crops back down
        # to IMAGE_SIZE from the top-left corner.
        'size': '{}x{}'.format(int(IMAGE_SIZE[0] * 1.1), int(IMAGE_SIZE[1] * 1.1)),
    }
    try:
        # Without a timeout a stalled connection would block the crawl forever.
        r = requests.get(BASE_URL, params=payload, stream=True, timeout=60)
    except requests.RequestException as e:
        print('[!] Request failed: {}'.format(e))
        return None
    try:
        if r.status_code == 200:
            r.raw.decode_content = True
            try:
                # convert() forces a full decode before the response is closed.
                return Image.open(r.raw).convert('RGB')
            except IOError:
                return None
        if r.status_code == 403:
            handle_api_limit()
        else:
            print('[!] Received status code: {}'.format(r.status_code))
        return None
    finally:
        # stream=True keeps the connection open until the response is closed.
        r.close()
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C stops the crawl quietly (exit status 0, no traceback).
        # Raise SystemExit directly: the builtin exit() is injected by the
        # `site` module for interactive use and is absent under `python -S`.
        raise SystemExit()
Add Comment
Please, Sign In to add comment