Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the sub folders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.

  Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

  Returns:
    An OrderedDict containing an entry for each label subfolder, with images
    split into 'training', 'testing', and 'validation' lists within each label,
    or None if image_dir does not exist.
  """
  if not gfile.Exists(image_dir):
    tf.logging.error("Image directory '" + image_dir + "' not found.")
    return None
  result = collections.OrderedDict()
  sub_dirs = [
      os.path.join(image_dir, item)
      for item in gfile.ListDirectory(image_dir)]
  sub_dirs = sorted(item for item in sub_dirs
                    if gfile.IsDirectory(item))
  # Loop-invariant: build the extension list once instead of per sub_dir.
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
  for sub_dir in sub_dirs:
    file_list = []
    dir_name = os.path.basename(sub_dir)
    if dir_name == image_dir:
      continue
    tf.logging.info("Looking for images in '" + dir_name + "'")
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(gfile.Glob(file_glob))
    # On case-insensitive filesystems the 'jpg' and 'JPG' globs match the
    # same files, double-counting every image — dedupe to fix that.
    file_list = sorted(set(file_list))
    if not file_list:
      tf.logging.warning('No files found')
      continue
    if len(file_list) < 20:
      tf.logging.warning(
          'WARNING: Folder has less than 20 images, which may cause issues.')
    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
      tf.logging.warning(
          'WARNING: Folder {} has more than {} images. Some images will '
          'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in, the data set creator has a way
      # of grouping photos that are close variations of each other. For
      # example this is used in the plant disease data set to group multiple
      # pictures of the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file
      # should go into the training, testing, or validation sets, and we want
      # to keep existing files in the same set even if more files are
      # subsequently added.
      # To do that, we need a stable way of deciding based on just the file
      # name itself, so we do a hash of that and then use that to generate a
      # probability value that we use to assign it.
      hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
      percentage_hash = ((int(hash_name_hashed, 16) %
                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
    # Report the realized split fractions through the logging framework.
    # (A debug `print` plus `time.sleep(5)` per class was removed here — it
    # stalled startup by five seconds for every label with no benefit.)
    total = len(training_images) + len(testing_images) + len(validation_images)
    tf.logging.info(
        'train: %.3f, test: %.3f, valid: %.3f',
        len(training_images) / total,
        len(testing_images) / total,
        len(validation_images) / total)
  return result
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement