Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
-- Build a single Torch7 dataset file from the CIFAR-10 .t7 batch files.
--
-- Reads data_batch_1..5.t7 and test_batch.t7 (ascii torch serialization) from
-- a batches folder, then writes a table of the form
--   { trainData = { data, labels, size() }, testData = { data, labels, size() } }
-- where `data` is permuted so the sample index comes first (N x 3 x 32 x 32)
-- and labels have been shifted by +1.
--
-- Usage: th <this script> [batches_folder] [output_file]
-- Both arguments are optional and default to the original hard-coded values.
require 'torch'
require 'paths'
require 'xlua'
require 'sys'

local batches_folder = (arg and arg[1]) or '/opt/rocks/cifar.torch/cifar-10-batches-t7'
local output_file = (arg and arg[2]) or 'cifar10_original.t7'
local num_train_batches = 5

-- Load one ascii-serialized batch file located in batches_folder.
local function load_batch(filename)
  local path = paths.concat(batches_folder, filename)
  -- Fail with the offending path rather than an opaque torch.load error.
  assert(paths.filep(path), 'CIFAR-10 batch not found: ' .. path)
  return torch.load(path, 'ascii')
end

-- View a batch's pixel tensor as 3 x 32 x 32 x N.
-- NOTE(review): assumes `.data` is contiguous with the image index along its
-- last dimension (e.g. 3072 x N); TODO confirm against the converter output.
local function as_chw_n(batch_data)
  return batch_data:view(3, 32, 32, -1)
end

-- Training set: concatenate the training batches along the sample dimension.
local train_data_parts, train_label_parts = {}, {}
for i = 1, num_train_batches do
  local part = load_batch('data_batch_' .. i .. '.t7')
  table.insert(train_data_parts, as_chw_n(part.data))
  table.insert(train_label_parts, part.labels:squeeze())
end
local data = torch.ByteTensor.cat(train_data_parts, 4)
local labels = torch.ByteTensor.cat(train_label_parts)

-- Test set: a single batch file.
local test_part = load_batch('test_batch.t7')
local test_data = as_chw_n(test_part.data)
local test_labels = test_part.labels:squeeze()

-- Shift labels by +1 in place (presumably 0-based class ids -> Lua/Torch
-- 1-based convention; the max printed below is a sanity check for that).
labels:add(1)
test_labels:add(1)

local dataset = {
  trainData = {
    -- permute moves the sample dimension first; clone makes it contiguous.
    data = data:permute(4, 1, 2, 3):clone(),
    labels = labels,
    size = function() return labels:numel() end,
  },
  testData = {
    data = test_data:permute(4, 1, 2, 3):clone(),
    labels = test_labels,
    size = function() return test_labels:numel() end,
  },
}

print(dataset)
print(dataset.trainData.labels:max())
print(dataset.testData.labels:max())
torch.save(output_file, dataset)
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement