Advertisement
Guest User

Untitled

a guest
Oct 25th, 2016
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.07 KB | None | 0 0
  1. require 'xlua'
  2. require 'sys'
  3.  
  4. local batches_folder = '/opt/rocks/cifar.torch/cifar-10-batches-t7'
  5.  
  6. local data = {}
  7. local labels = {}
  8.  
  9. for i=1,5 do
  10. local name = paths.concat(batches_folder, 'data_batch_'..i..'.t7')
  11. local part = torch.load(paths.concat(batches_folder, name), 'ascii')
  12. table.insert(data, part.data:view(3,32,32,-1))
  13. table.insert(labels, part.labels:squeeze())
  14. end
  15.  
  16. data = torch.ByteTensor.cat(data, 4)
  17. labels = torch.ByteTensor.cat(labels)
  18.  
  19. test_part = torch.load(paths.concat(batches_folder, 'test_batch.t7'), 'ascii')
  20. test_labels = test_part.labels
  21. test_data = test_part.data
  22.  
  23.  
  24. local dataset = {
  25. trainData = {
  26. data = data:permute(4,1,2,3):clone(),
  27. labels = labels:add(1),
  28. size = function() return labels:numel() end,
  29. },
  30. testData = {
  31. data = test_data:view(3,32,32,-1):permute(4,1,2,3):clone(),
  32. labels = test_labels:squeeze():add(1),
  33. size = function() return test_labels:numel() end,
  34. }
  35. }
  36.  
  37. print(dataset)
  38. print(dataset.trainData.labels:max())
  39. print(dataset.testData.labels:max())
  40. torch.save('cifar10_original.t7', dataset)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement