Advertisement
Guest User

Untitled

a guest
Apr 13th, 2020
264
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.88 KB | None | 0 0
import glob
import io
import os
import sys
from random import shuffle

import cv2
import numpy as np
import PIL.Image as Image
import tensorflow as tf
  11.  
  12.  
  13. def encode_utf8_string(text, length, dic, null_char_id=0):
  14.     text = text
  15.     char_ids_padded = [null_char_id]*length
  16.     for i in range(len(text)):
  17.         hash_id = dic[text[i]]
  18.         char_ids_padded[i] = hash_id
  19.     print(char_ids_padded)
  20.     return char_ids_padded
  21.  
  22. def _bytes_feature(value):
  23.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  24.  
  25. def _int64_feature(value):
  26.     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
  27.  
  28. dict={}
  29. with open('dic.txt', encoding="utf-8-sig") as dict_file:
  30.     for line in dict_file:
  31.         (key, value) = line.strip().split('\t')
  32.         dict[value] = int(key)
  33. print((dict))
  34.  
  35. image_path = 'dataset1/**.jpg'
  36. addrs_image =  glob.glob(image_path,  
  37.                    recursive = True)
  38. shuffle(addrs_image)
  39. print(len(addrs_image))
  40.  
  41. tfrecord_writer  = tf.io.TFRecordWriter("tfexample_train")
  42. for j in range(0,int(len(addrs_image))):
  43.    
  44.  
  45.             # 这是写入操作可视化处理
  46.     print('Train data: {}/{}'.format(j,int(len(addrs_image))))
  47.     sys.stdout.flush()
  48.     img = Image.open(addrs_image[j])
  49.     np_data = np.array(img)
  50.     imgByteArr = io.BytesIO()
  51.     img.save(imgByteArr, format='JPEG')
  52.     image_data = imgByteArr.getvalue()
  53.     print("Current File: " + addrs_image[j])
  54.     for text in open(addrs_image[j].replace('.jpg', '.txt'), encoding="utf-8-sig"):
  55.                  char_ids_padded = encode_utf8_string(
  56.                             text=("<" + text),
  57.                             dic=dict,
  58.                             length=42,
  59.                             null_char_id=0)
  60.     for text in open(addrs_image[j].replace('.jpg', '.txt'), encoding="utf-8-sig"):
  61.                  char_ids_unpadded = encode_utf8_string(
  62.                             text=(text + ">"),
  63.                             dic=dict,
  64.                             length=42,
  65.                             null_char_id=0)
  66.  
  67.  
  68.     example = tf.train.Example(features=tf.train.Features(
  69.                         feature={
  70.                             'image/encoded': _bytes_feature(image_data),
  71.                             'image/format': _bytes_feature(b"JPEG"),
  72.                             'image/width': _int64_feature([np_data.shape[1]]),
  73.                             'image/orig_width': _int64_feature([np_data.shape[1]]),
  74.                             'image/class': _int64_feature(char_ids_padded),
  75.                             'image/unpadded_class': _int64_feature(char_ids_unpadded),
  76.                             'image/text': _bytes_feature(bytes(text, 'utf-8-sig')),
  77.                         }
  78.                     ))
  79.     tfrecord_writer.write(example.SerializeToString())
  80. tfrecord_writer.close()
  81.  
  82. sys.stdout.flush()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement