Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- # -*- coding: utf-8 -*-
- import pickle
- import numpy as np
def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    ``encoding='latin-1'`` lets Python 3 read pickles written by Python 2:
    8-bit ``str`` objects in the pickle are decoded as latin-1 instead of
    the default ASCII, which would raise on non-ASCII bytes.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.

    Security note: ``pickle.load`` can execute arbitrary code embedded in
    the file. Only call this on files from a trusted source.
    """
    with open(file, 'rb') as fo:
        # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='latin-1')
    return batch
def load_pkl(files):
    """Load pickled image batches and stack them into image/label arrays.

    Each file is read with ``unpickle`` and must hold a mapping with a
    ``'data'`` entry (one row of 3 * 32 * 32 = 3072 values per image, laid
    out as three 32x32 channel planes) and a ``'labels'`` entry with one
    label per row. Batches may contain any number of rows; they are
    concatenated in the order given.

    Args:
        files: Sequence of batch file paths (list, tuple, ...). A single
            path string is rejected so it is not iterated character by
            character.

    Returns:
        ``(x, y)``:
        - ``x`` is a uint8 array of shape (N, 3, 32, 32) with the channel
          axis reversed (RGB -> BGR).
        - ``y`` is a uint8 array of shape (N,).

    Raises:
        TypeError: if ``files`` is a single ``str``/``bytes`` path.
        ValueError: if a batch has a different number of images and labels.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if isinstance(files, (str, bytes)):
        raise TypeError('files must be a sequence of paths, not a single path')

    xs, ys = [], []
    for file in files:
        d = unpickle(file)
        data = np.asarray(d['data'], dtype=np.uint8).reshape(-1, 3 * 32 * 32)
        labels = np.asarray(d['labels'], dtype=np.uint8).reshape(-1)
        if len(data) != len(labels):
            raise ValueError(
                '{}: {} images but {} labels'.format(file, len(data), len(labels)))
        xs.append(data)
        ys.append(labels)

    if not xs:
        # No files: return correctly shaped empty arrays.
        return (np.empty((0, 3, 32, 32), dtype=np.uint8),
                np.empty((0,), dtype=np.uint8))

    # Concatenating per-file arrays (rather than filling a fixed
    # (n_files, 10000, 3072) buffer) lets batches have any size.
    x = np.concatenate(xs)
    y = np.concatenate(ys)
    # Each row holds 3 channel planes of 32x32; reversing axis 1 flips the
    # channel order.
    x = x.reshape(-1, 3, 32, 32)[:, ::-1, :, :]  # RGB to BGR
    return x, y
def preprocess(x, y):
    """Cast images and labels to the dtypes used downstream.

    Images become float32 and are scaled by 1/255. For 8-bit pixel input
    this maps [0, 255] into [0, 1]. Labels become int32. Both casts
    produce new arrays, so the caller's inputs are left unchanged.

    Returns:
        ``(images, labels)`` as ``(float32 array, int32 array)``.
    """
    labels = y.astype(np.int32)
    pixels = x.astype(np.float32)
    # Scale in place on the fresh float32 copy to avoid a second full-size buffer.
    np.multiply(pixels, 1. / 255, out=pixels)
    return pixels, labels
def _convert_split(files, out_path):
    """Load the batches in ``files``, preprocess them and save them to ``out_path``.

    Writes an uncompressed ``.npz`` archive with two arrays: ``data``
    (images) and ``label`` (labels). The arrays are local to this call, so
    they can be released before the next split is loaded.
    """
    x, y = load_pkl(files)
    x, y = preprocess(x, y)
    np.savez(out_path, data=x, label=y)


def main():
    """Convert the pickled train and test batches to ``train.npz`` / ``test.npz``.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch`` from the
    current working directory and writes the archives there.
    """
    train_files = ['data_batch_{}'.format(i + 1) for i in range(5)]
    test_files = ['test_batch']
    # Both splits go through the same helper, so each archive pairs its
    # images with its own labels (the test split used to save images as
    # labels).
    _convert_split(train_files, 'train.npz')
    _convert_split(test_files, 'test.npz')
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
Add Comment
Please sign in to add a comment.