Guest User

Untitled

a guest
May 21st, 2018
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.22 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. import pickle
  3. import numpy as np
  4.  
  5.  
  6. def unpickle(file):
  7. with open(file, 'rb') as fo:
  8. dict = pickle.load(fo, encoding='latin-1')
  9. return dict
  10.  
  11. def load_pkl(files):
  12. assert(isinstance(files, list))
  13. n_files = len(files)
  14. x = np.empty((n_files, 10000, 3072), dtype=np.uint8)
  15. y = np.empty((n_files, 10000), dtype=np.uint8)
  16. for i, file in enumerate(files):
  17. d = unpickle(file)
  18. x[i] = d['data']
  19. y[i] = d['labels']
  20.  
  21. total_samples = n_files * 10000
  22. x = x.reshape(total_samples, 3, 32, 32)[:, ::-1, :, :] # RGB to BGR
  23. y = y.reshape(total_samples)
  24. return x, y
  25.  
  26. def preprocess(x, y):
  27. x = x.astype(np.float32)
  28. y = y.astype(np.int32)
  29. x *= 1. / 255
  30. return x, y
  31.  
  32. def main():
  33. train_files = ['data_batch_{}'.format(i + 1) for i in range(5)]
  34. test_files = ['test_batch']
  35.  
  36. train_x, train_y = load_pkl(train_files)
  37. train_x, train_y = preprocess(train_x, train_y)
  38. np.savez('train.npz', data=train_x, label=train_y)
  39. del train_x, train_y
  40.  
  41. test_x, test_y = load_pkl(test_files)
  42. test_x, test_y = preprocess(test_x, test_y)
  43. np.savez('test.npz', data=test_x, label=test_x)
  44. del test_x, test_y
  45.  
  46. if __name__ == '__main__':
  47. main()
Add Comment
Please, Sign In to add comment