daily pastebin goal
59%
SHARE
TWEET

Untitled

a guest Dec 13th, 2018 47 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import tensorflow as tf
  2. import numpy as np
  3. import pickle
  4.  
  5. train_x = pickle.load(open('data/train_x.pkl', 'rb'))
  6. train_y = pickle.load(open('data/train_y.pkl', 'rb'))
  7. val_x = pickle.load(open('data/val_x.pkl', 'rb'))
  8. val_y = pickle.load(open('data/val_y.pkl', 'rb'))
  9. test_x = pickle.load(open('data/test_x.pkl', 'rb'))
  10.  
  11. train_x = [[int(x) for x in t] for t in train_x]  # 原pickle文件都是string类型
  12. val_x = [[int(x) for x in v] for v in val_x]
  13. test_x = [[int(x) for x in t] for t in test_x]
  14.  
  15. subtract = lambda y: int(y) - 1
  16. train_y = list(map(subtract, train_y))  # 因为预测的时候网络的输出的类别是从0开始的,最后在预测test集的时候应将预测结果都加1
  17. val_y = list(map(subtract, val_y))
  18.  
  19. PAD = 1280000  # 将train集合和valid集合都id都统计了一边,总共是0~1279999,因此用1280000作为padding操作的id
  20.  
  21. padding_length = 1000  # 截取文章的前800个词
  22.  
  23. with tf.python_io.TFRecordWriter('./data/train.tfrecord') as writer:
  24.     for x, y in zip(train_x, train_y):
  25.         if len(x) > padding_length:
  26.             x = x[:padding_length]
  27.             mask = np.ones_like(x)
  28.         else:
  29.             mask = np.ones_like(x)
  30.             x = np.pad(x, (0, padding_length - len(x)), mode='constant', constant_values=(0, PAD))
  31.             mask = np.pad(mask, (0, padding_length - len(mask)), mode='constant', constant_values=(0, 0))
  32.  
  33.         example = tf.train.Example(features=tf.train.Features(feature={
  34.             'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x)),
  35.             'mask': tf.train.Feature(int64_list=tf.train.Int64List(value=mask)),
  36.             'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))
  37.         }))
  38.         tf_example = example.SerializeToString()
  39.         writer.write(tf_example)
  40.  
  41. with tf.python_io.TFRecordWriter('./data/valid.tfrecord') as writer:
  42.     for x, y in zip(val_x, val_y):
  43.         if len(x) > padding_length:
  44.             x = x[:padding_length]
  45.             mask = np.ones_like(x)
  46.         else:
  47.             mask = np.ones_like(x)
  48.             x = np.pad(x, (0, padding_length - len(x)), mode='constant', constant_values=(0, PAD))
  49.             mask = np.pad(mask, (0, padding_length - len(mask)), mode='constant', constant_values=(0, 0))
  50.  
  51.         example = tf.train.Example(features=tf.train.Features(feature={
  52.             'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x)),
  53.             'mask': tf.train.Feature(int64_list=tf.train.Int64List(value=mask)),
  54.             'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))
  55.         }))
  56.         tf_example = example.SerializeToString()
  57.         writer.write(tf_example)
  58.  
  59. with tf.python_io.TFRecordWriter('./data/test.tfrecord') as writer:
  60.     for x in test_x:
  61.         if len(x) > padding_length:
  62.             x = x[:padding_length]
  63.             mask = np.ones_like(x)
  64.         else:
  65.             mask = np.ones_like(x)
  66.             x = np.pad(x, (0, padding_length - len(x)), mode='constant', constant_values=(0, PAD))
  67.             mask = np.pad(mask, (0, padding_length - len(mask)), mode='constant', constant_values=(0, 0))
  68.  
  69.         example = tf.train.Example(features=tf.train.Features(feature={
  70.             'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x)),
  71.             'mask': tf.train.Feature(int64_list=tf.train.Int64List(value=mask)),
  72.             'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))
  73.         }))
  74.         tf_example = example.SerializeToString()
  75.         writer.write(tf_example)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top