Guest User

Untitled

a guest
Oct 21st, 2017
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.87 KB | None | 0 0
  1. import io
  2. import os
  3. import numpy as np
  4. from PIL import Image
  5. from pymongo import MongoClient
  6. from torch.utils.data import Dataset, DataLoader
  7. from torchvision import transforms
  8.  
  9.  
  10. def pil_loader(f):
  11. with Image.open(io.BytesIO(f)) as img:
  12. return img.convert('RGB')
  13.  
  14.  
  15. class DatasetDB(Dataset):
  16. def __init__(self, db_name='images', col_name='train', transform=None):
  17. self._label_dtype = np.int32
  18. self.transform = transform
  19.  
  20. client = MongoClient('localhost', 27017)
  21. db = client[db_name]
  22. self.col = db[col_name]
  23. self.examples = list(self.col.find({}, {'imgs': 0}))
  24. self.labels = self.get_labels()
  25. print(self.labels)
  26. # self.labels = dict([(line.strip(), idx) for idx, line in enumerate(open(labels_txt, "r"))])
  27.  
  28. def __len__(self):
  29. return len(self.examples)
  30.  
  31. def get_labels(self):
  32. category_ids = [e['category_id'] for e in self.examples]
  33. return {cid: i for i, cid in enumerate(sorted(list(set(category_ids))))}
  34.  
  35. def __getitem__(self, i):
  36. _id = self.examples[i]['_id']
  37. doc = self.col.find_one({'_id': _id})
  38.  
  39. img = doc['imgs'][0]['picture']
  40. img = pil_loader(img)
  41.  
  42. if self.transform:
  43. img = self.transform(img)
  44.  
  45. label = self.labels[doc['category_id']]
  46. assert type(label) == int
  47. return img, label, _id
  48.  
  49. #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
  50. # std=[0.229, 0.224, 0.225])
  51. #
  52. #transform = transforms.Compose([
  53. # transforms.RandomSizedCrop(224),
  54. # transforms.RandomHorizontalFlip(),
  55. # transforms.ToTensor(),
  56. # ])
  57. #
  58. #dataset = DatasetDB(transform=transform)
  59. #loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=8)
  60. #for image, label, prod in loader:
  61. # print(image.max())
  62. # print(image.min())
  63. # print(" --- ")
Add Comment
Please, Sign In to add comment