Advertisement
Guest User

bug1

a guest
Nov 21st, 2016
276
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.79 KB | None | 0 0
  1. import sys
  2. import apache_beam as beam
  3. import google.cloud.ml as ml
  4. import google.cloud.ml.io as io
  5. from google.cloud.ml.features._features import GraphOptions
  6. import tensorflow as tf
  7.  
  8.  
  9. def main(hack):
  10.     pipeline = beam.Pipeline('DirectPipelineRunner')
  11.  
  12.     csv_columns = ('id', 'image_filename',)
  13.     feature_set = {
  14.         'key': ml.features.key('id'),
  15.         # The 'image' feature is a ImageFeatureColumn that reads from the CSV's
  16.         # 'image_filename' column.
  17.         'image': ml.features.image('image_filename').image(
  18.             target_size=(256, 256)),
  19.     }
  20.  
  21.     # Mock reading in a CSV.
  22.     rows = pipeline | beam.Create([
  23.         {'id': 1, 'image_filename': 'image.jpg'}
  24.     ])
  25.  
  26.     # Run preprocess step, which generates metadata and processes the raw CSV
  27.     # rows into features.
  28.     (metadata, features) = (
  29.             rows
  30.             | 'Preprocess' >> ml.Preprocess(
  31.                 feature_set,
  32.                 input_format='csv',
  33.                 format_metadata={
  34.                     'headers': csv_columns
  35.                 }))
  36.  
  37.     # This hack makes the script work.
  38.     if hack:
  39.         metadata = metadata | beam.Map(metadata_image_hack)
  40.  
  41.     # Using metadata, create a Tensor that accepts the tf.Examples as a
  42.     # feed_dict.  This is where the script breaks.
  43.     metadata | beam.Map(create_tensor)
  44.  
  45.     pipeline.run()
  46.  
  47.  
  48. def create_tensor(metadata):
  49.     # Wrap metadata dict inside an object where the dict's keys can be accessed
  50.     # as the object's attributes.
  51.     metadata = GraphOptions(metadata)
  52.  
  53.     # Create a placeholder for tf.Example-encoded features.
  54.     placeholder = tf.placeholder(tf.string, name='input', shape=(None,))
  55.  
  56.     # Parse the tf.Example-encoded features into a dict mapping feature names
  57.     # to feature Tensors, each with a dtype and shape.
  58.     features = ml.features.FeatureMetadata.parse_features(metadata, placeholder)
  59.  
  60.     # Get the feature Tensor for the 'image' feature.  This should contain a
  61.     # batch of JPEG strings (1 per example).
  62.     image_feature = features['image']
  63.     batch_size = image_feature.get_shape()[0]
  64.  
  65.     # For each JPEG string in the batch, decode the jpeg from the feature
  66.     # Tensor.
  67.     #
  68.     # Since the 'image' feature Tensor is given a shape (None, >1) by
  69.     # FeatureMetadata.parse_features, this is impossible.
  70.     imgs = tf.map_fn(lambda x: tf.image.decode_jpeg(tf.reshape(x, [])),
  71.             image_feature, dtype=tf.uint8)
  72.  
  73.     return imgs
  74.  
  75.  
  76. def metadata_image_hack(metadata_raw):
  77.     """This is a hack to fix a bug in Google ML's ImageTransform:
  78.    _transforms.ImageTransform.transform() serializes images to string-encoded
  79.    JPEG/PNG scalars, but sets _transforms.ImageTransform.feature_size to (8 *
  80.    target_size[0] * target_size[1]).  But when deserializing (in
  81.    _features.FeatureMetadata.parse_features()), we need feature_size to be 1,
  82.    or else parse_features breaks.
  83.  
  84.    Here, we replace the metadata value to specify the image size as 1 instead
  85.    of >1 so that deserialization works.
  86.  
  87.    """
  88.     import copy
  89.  
  90.     image_col_names = set()
  91.     for col_name, col_attrs in metadata_raw['columns'].items():
  92.         if col_attrs['type'] == 'image' and col_attrs['transform'] == 'image':
  93.             image_col_names.add(col_name)
  94.  
  95.     metadata = copy.deepcopy(metadata_raw)
  96.     for feat_name, feat_attrs in metadata['features'].items():
  97.         # if feature includes image column name:
  98.         if len(set(feat_attrs['columns']) & image_col_names) > 0:
  99.             print("changing size for image feature: %s" % feat_name)
  100.             metadata['features'][feat_name]['size'] = 1
  101.  
  102.     return metadata
  103.  
  104.  
  105. if __name__ == "__main__":
  106.     if len(sys.argv) > 1 and sys.argv[1] == '--hack':
  107.         hack = True
  108.     else:
  109.         hack = False
  110.     main(hack)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement