Advertisement
Guest User

Untitled

a guest
Nov 17th, 2019
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.81 KB | None | 0 0
  def create_image_lists(image_dir, testing_percentage, validation_percentage):
    """Builds a list of training images from the file system.

    Analyzes the sub folders in the image directory, splits them into stable
    training, testing, and validation sets, and returns a data structure
    describing the lists of images for each label and their paths.

    Args:
      image_dir: String path to a folder containing subfolders of images.
      testing_percentage: Integer percentage of the images to reserve for tests.
      validation_percentage: Integer percentage of images reserved for validation.

    Returns:
      An OrderedDict with one entry per label subfolder that contains at least
      one JPEG image. Each key is the folder name lower-cased with every run of
      characters outside [a-z0-9] replaced by a single space; each value is a
      dict with the keys 'dir' (the subfolder name) and 'training', 'testing',
      'validation' (lists of image base names). Returns None if image_dir does
      not exist.
    """
    if not gfile.Exists(image_dir):
      tf.logging.error("Image directory '" + image_dir + "' not found.")
      return None
    result = collections.OrderedDict()
    # Only the immediate children of image_dir that are directories count as
    # label folders; sorting keeps the label order deterministic.
    candidate_paths = [
        os.path.join(image_dir, item)
        for item in gfile.ListDirectory(image_dir)]
    sub_dirs = sorted(path for path in candidate_paths
                      if gfile.IsDirectory(path))
    # Loop-invariant, so built once rather than once per folder.
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    for sub_dir in sub_dirs:
      # NOTE: image_dir itself is never among sub_dirs (they are all children
      # of it), so no root-directory check is needed here. Comparing dir_name
      # against image_dir would wrongly skip a label folder that happens to
      # share its name with a relative image_dir (e.g. 'flowers/flowers').
      dir_name = os.path.basename(sub_dir)
      tf.logging.info("Looking for images in '" + dir_name + "'")
      file_list = []
      seen_paths = set()
      for extension in extensions:
        file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
        # Where glob matching is case-insensitive, the lower- and upper-case
        # patterns can return the same path; keep only the first occurrence so
        # an image is never listed (and later sampled) twice.
        for path in gfile.Glob(file_glob):
          if path not in seen_paths:
            seen_paths.add(path)
            file_list.append(path)
      if not file_list:
        tf.logging.warning('No files found')
        continue
      if len(file_list) < 20:
        tf.logging.warning(
            'WARNING: Folder has less than 20 images, which may cause issues.')
      elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
        tf.logging.warning(
            'WARNING: Folder {} has more than {} images. Some images will '
            'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
      label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
      if label_name in result:
        # Distinct folder names can normalise to the same label (e.g. 'Cat' and
        # 'cat'); the later folder replaces the earlier one, so say so.
        tf.logging.warning(
            "Label '{}' from folder '{}' replaces the entry from folder "
            "'{}'.".format(label_name, dir_name, result[label_name]['dir']))
      training_images = []
      testing_images = []
      validation_images = []
      for file_name in file_list:
        base_name = os.path.basename(file_name)
        # We want to ignore anything after '_nohash_' in the file name when
        # deciding which set to put an image in, the data set creator has a way
        # of grouping photos that are close variations of each other. For
        # example this is used in the plant disease data set to group multiple
        # pictures of the same leaf.
        # NOTE: the hash input is the path returned by the glob (it includes
        # image_dir), not just the base name, so the split depends on the
        # image_dir string that was passed in.
        hash_name = re.sub(r'_nohash_.*$', '', file_name)
        # This looks a bit magical, but we need to decide whether this file
        # should go into the training, testing, or validation sets, and we want
        # to keep existing files in the same set even if more files are
        # subsequently added.
        # To do that, we need a stable way of deciding based on just the file
        # name itself, so we do a hash of that and then use that to generate a
        # probability value that we use to assign it.
        hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
        # Maps the hash onto a value in the closed range [0.0, 100.0].
        percentage_hash = ((int(hash_name_hashed, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
        if percentage_hash < validation_percentage:
          validation_images.append(base_name)
        elif percentage_hash < (testing_percentage + validation_percentage):
          testing_images.append(base_name)
        else:
          training_images.append(base_name)
      result[label_name] = {
          'dir': dir_name,
          'training': training_images,
          'testing': testing_images,
          'validation': validation_images,
      }
      # Per-label split summary, reported through the same logger as the rest
      # of this function instead of a bare print to stdout.
      tf.logging.info(
          "'%s' -> train: %d, test: %d, valid: %d", label_name,
          len(training_images), len(testing_images), len(validation_images))
    return result
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement