• Sign Up
• Login
• API
• FAQ
• Tools
• Archive
daily pastebin goal
59%
SHARE
TWEET

# Untitled

a guest Dec 13th, 2018 47 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import tensorflow as tf
2. import numpy as np
3. import pickle
4.
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in binary mode and closed as soon as loading
    finishes, so no file handle is left open.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only load
    # files that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Token sequences (x) and labels (y) for each split; only x is loaded for
# the test split.
train_x = _load_pickle('data/train_x.pkl')
train_y = _load_pickle('data/train_y.pkl')
val_x = _load_pickle('data/val_x.pkl')
val_y = _load_pickle('data/val_y.pkl')
test_x = _load_pickle('data/test_x.pkl')
10.
# Per the original author's note, the pickled files hold every token id as a
# string, so each sequence is converted to a list of ints.
train_x = [list(map(int, sequence)) for sequence in train_x]
val_x = [list(map(int, sequence)) for sequence in val_x]
test_x = [list(map(int, sequence)) for sequence in test_x]
14.
def subtract(y):
    """Return ``y`` converted to int and shifted down by one."""
    return int(y) - 1


# The network's predicted classes start at 0, so the labels are shifted down
# by one here. When predicting on the test set, 1 must be added back to every
# prediction.
train_y = [subtract(label) for label in train_y]
val_y = [subtract(label) for label in val_y]
18.
# Per the original author's note, the ids seen in the train and validation
# sets span 0..1279999, so 1280000 is used as the padding id.
PAD = 1280000

# Every article is truncated or padded to exactly this many tokens.
padding_length = 1000


def _pad_with_mask(tokens, length=padding_length, pad_id=PAD):
    """Return ``(ids, mask)`` as int64 arrays of exactly ``length`` entries.

    ``ids`` holds the first ``length`` tokens followed by ``pad_id`` filler.
    ``mask`` is 1 at positions that hold a real token and 0 at padded ones.
    """
    # The explicit int64 dtype keeps an empty sequence from becoming a
    # float64 array, which is not valid input for an int64 feature.
    kept = np.asarray(tokens[:length], dtype=np.int64)
    ids = np.full(length, pad_id, dtype=np.int64)
    ids[:kept.size] = kept
    mask = np.zeros(length, dtype=np.int64)
    mask[:kept.size] = 1
    return ids, mask


def _int64_feature(values):
    """Wrap an iterable of ints in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _write_tfrecord(path, sequences, labels):
    """Write one serialized ``tf.train.Example`` per (sequence, label) pair.

    Each example has three int64 features: 'x' (padded token ids),
    'mask' (1 for real tokens, 0 for padding) and 'y' (a single label).
    """
    # NOTE: tf.python_io is the TensorFlow 1.x namespace; TensorFlow 2.x
    # exposes the same writer as tf.io.TFRecordWriter.
    with tf.python_io.TFRecordWriter(path) as writer:
        for tokens, label in zip(sequences, labels):
            ids, mask = _pad_with_mask(tokens)
            example = tf.train.Example(features=tf.train.Features(feature={
                'x': _int64_feature(ids),
                'mask': _int64_feature(mask),
                'y': _int64_feature([label]),
            }))
            writer.write(example.SerializeToString())


_write_tfrecord('./data/train.tfrecord', train_x, train_y)
_write_tfrecord('./data/valid.tfrecord', val_x, val_y)
# No labels are loaded for the test split; a constant 1 is written as 'y'
# (presumably a placeholder so all three files share the same features).
_write_tfrecord('./data/test.tfrecord', test_x, [1] * len(test_x))
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top