1. import tensorflow as tf
2. import numpy as np
3. import pickle
4.
10.
11. train_x = [[int(x) for x in t] for t in train_x]  # the pickled data stores these values as strings; convert every element to int
12. val_x = [[int(x) for x in v] for v in val_x]  # same str -> int conversion for the validation set
13. test_x = [[int(x) for x in t] for t in test_x]  # same str -> int conversion for the test set
14.
15. subtract = lambda y: int(y) - 1  # parse a label with int() and shift it down by one
16. train_y = list(map(subtract, train_y))  # the network's predicted classes start at 0, so labels are made 0-based here; add 1 back to the predictions when predicting the test set
17. val_y = list(map(subtract, val_y))  # same 0-based shift for the validation labels
18.
20.
21. padding_length = 1000  # number of leading words kept from each document (the previous comment said 800, which did not match this value)
22.
23. with tf.python_io.TFRecordWriter('./data/train.tfrecord') as writer:  # write the training set; NOTE(review): tf.python_io is the TF 1.x API - on TF 2.x use tf.io.TFRecordWriter
24.     for x, y in zip(train_x, train_y):  # one Example per (document, label) pair; zip stops at the shorter list
28.         else:  # NOTE(review): the matching if-branch and this else-body are not shown here; presumably x is truncated/padded to padding_length - confirm
32.
33.         example = tf.train.Example(features=tf.train.Features(feature={  # build one Example with int64 features
34.             'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x)),  # the document's integer sequence from train_x
36.             'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))  # 0-based class label (shifted by subtract above)
37.         }))
38.         tf_example = example.SerializeToString()  # serialize the proto to bytes
39.         writer.write(tf_example)
40.
41. with tf.python_io.TFRecordWriter('./data/valid.tfrecord') as writer:  # write the validation set; NOTE(review): tf.python_io is the TF 1.x API - on TF 2.x use tf.io.TFRecordWriter
42.     for x, y in zip(val_x, val_y):  # one Example per (document, label) pair; zip stops at the shorter list
46.         else:  # NOTE(review): the matching if-branch and this else-body are not shown here; presumably x is truncated/padded to padding_length - confirm
50.
51.         example = tf.train.Example(features=tf.train.Features(feature={  # build one Example with int64 features
52.             'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x)),  # the document's integer sequence from val_x
54.             'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))  # 0-based class label (shifted by subtract above)
55.         }))
56.         tf_example = example.SerializeToString()  # serialize the proto to bytes
57.         writer.write(tf_example)
58.
59. with tf.python_io.TFRecordWriter('./data/test.tfrecord') as writer:  # write the test set; NOTE(review): tf.python_io is the TF 1.x API - on TF 2.x use tf.io.TFRecordWriter
60.     for x in test_x:  # only documents are iterated here; no test labels are used
64.         else:  # NOTE(review): the matching if-branch and this else-body are not shown here; presumably x is truncated/padded to padding_length - confirm
68.
69.         example = tf.train.Example(features=tf.train.Features(feature={  # build one Example with int64 features
70.             'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x)),  # the document's integer sequence from test_x
72.             'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))  # constant placeholder label, presumably so the record layout matches train/valid - do not treat as ground truth
73.         }))
74.         tf_example = example.SerializeToString()  # serialize the proto to bytes
75.         writer.write(tf_example)
