Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from random import shuffle
- import numpy as np
- import glob
- import tensorflow as tf
- import cv2
- import sys
- import os
- import PIL.Image as Image
- import io
- from random import shuffle
def encode_utf8_string(text, length, dic, null_char_id=0):
    """Encode *text* as a fixed-length list of character ids.

    Args:
        text: String to encode; each character is looked up in ``dic``.
        length: Number of entries in the returned list.
        dic: Mapping from a single character to its integer id.
        null_char_id: Id used to fill the positions after the end of ``text``.

    Returns:
        A list of exactly ``length`` ids: the ids of the characters of
        ``text`` followed by ``null_char_id`` padding.

    Raises:
        IndexError: If ``text`` has more than ``length`` characters.
        KeyError: If a character of ``text`` is missing from ``dic``.
    """
    if len(text) > length:
        # Fail up front and name the offending text, instead of the bare
        # "list assignment index out of range" raised half-way through.
        raise IndexError(
            "text {!r} has {} characters, more than length={}".format(
                text, len(text), length))

    char_ids_padded = [dic[char] for char in text]
    char_ids_padded.extend([null_char_id] * (length - len(text)))
    # Trace of the encoded ids on stdout (kept from the original script).
    print(char_ids_padded)
    return char_ids_padded
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    ``value`` is placed in a one-element list before being handed to
    ``tf.train.BytesList``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _int64_feature(value):
    """Wrap a sequence of integers in a ``tf.train.Feature``.

    ``value`` is passed to ``tf.train.Int64List`` as-is, so callers supply a
    list even when there is only one number.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
# Character -> id lookup table loaded from ``dic.txt``, which holds one
# "<id>\t<character>" entry per line.
# NOTE: the name shadows the builtin ``dict``; it is kept because the rest of
# the script refers to the table by this name.
dict = {}
with open('dic.txt', encoding="utf-8-sig") as dict_file:
    for line in dict_file:
        # Remove only the line terminator: a plain strip() would also eat an
        # entry whose character is a space and make the unpacking below fail.
        line = line.rstrip('\r\n')
        if not line:
            # Tolerate blank lines, e.g. a trailing one at the end of the file.
            continue
        key, value = line.split('\t')
        dict[value] = int(key)
# Show the loaded table once, after the whole file has been read.
print(dict)
# Build the training TFRecord: one tf.train.Example per image/label pair.
# Each ``<name>.jpg`` is expected to have its label text in ``<name>.txt``.

# Fixed length every encoded label is padded to.
SEQ_LENGTH = 42

# '**' only recurses when it is a whole path component; the previous
# 'dataset1/**.jpg' behaved like 'dataset1/*.jpg' despite recursive=True.
image_path = 'dataset1/**/*.jpg'
addrs_image = glob.glob(image_path, recursive=True)
shuffle(addrs_image)
print(len(addrs_image))

# The context manager closes the writer even when an image or label fails.
with tf.io.TFRecordWriter("tfexample_train") as tfrecord_writer:
    for j, addr in enumerate(addrs_image):
        # Progress display for the write loop.
        print('Train data: {}/{}'.format(j, len(addrs_image)))
        sys.stdout.flush()

        # Re-encode the image as JPEG in memory and remember its width.
        with Image.open(addr) as img:
            width = img.width
            imgByteArr = io.BytesIO()
            img.save(imgByteArr, format='JPEG')
        image_data = imgByteArr.getvalue()
        print("Current File: " + addr)

        # splitext swaps only the real extension; str.replace would rewrite
        # every '.jpg' occurring anywhere in the path.
        label_path = os.path.splitext(addr)[0] + '.txt'
        # Read the label once and close the file. The line terminator is
        # dropped so it is not looked up in the character table.
        with open(label_path, encoding="utf-8-sig") as label_file:
            text = label_file.read().rstrip('\r\n')
        if not text:
            # Previously an empty label silently reused the ids of the
            # previous image (or raised NameError on the first one).
            raise ValueError('Empty label file: ' + label_path)

        # Label prefixed with "<" and label suffixed with ">", both padded
        # to SEQ_LENGTH with id 0.
        char_ids_padded = encode_utf8_string(
            text="<" + text,
            dic=dict,
            length=SEQ_LENGTH,
            null_char_id=0)
        # NOTE(review): despite the feature name, these ids are padded too;
        # confirm this is what the consumer of the record expects.
        char_ids_unpadded = encode_utf8_string(
            text=text + ">",
            dic=dict,
            length=SEQ_LENGTH,
            null_char_id=0)

        example = tf.train.Example(features=tf.train.Features(
            feature={
                'image/encoded': _bytes_feature(image_data),
                'image/format': _bytes_feature(b"JPEG"),
                'image/width': _int64_feature([width]),
                'image/orig_width': _int64_feature([width]),
                'image/class': _int64_feature(char_ids_padded),
                'image/unpadded_class': _int64_feature(char_ids_unpadded),
                # Plain UTF-8: the 'utf-8-sig' codec would prepend a BOM to
                # the stored text.
                'image/text': _bytes_feature(text.encode('utf-8')),
            }
        ))
        tfrecord_writer.write(example.SerializeToString())
sys.stdout.flush()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement