Guest User

Untitled

a guest
Oct 21st, 2018
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.36 KB | None | 0 0
  1. # 读取图片和标签,进行数据输入
  2. # 适合csv格式 图片名称 + 标签
  3. from __future__ import print_function, division
  4. import os
  5. import torch
  6. import pandas as pd
  7. from skimage import io, transform
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from torch.utils.data import Dataset, DataLoader
  11. from torchvision import transforms, utils
  12.  
  13.  
  14.  
  15. class ClassifyDataset(Dataset):
  16. """分类图像数据集"""
  17.  
  18. def __init__(self, csv_file, root_dir, transform=None):
  19. """
  20. Args:
  21. csv_file (string): 带有标记点的csv文件路径
  22. root_dir (string): 图片路径
  23. transform (callable, optional):可选择进行的图像变换
  24. """
  25. self.landmarks_frame = pd.read_csv(csv_file)
  26. self.root_dir = root_dir
  27. self.transform = transform
  28.  
  29. def __len__(self):
  30. return len(self.landmarks_frame)
  31.  
  32. def __getitem__(self, idx):
  33. img_name = os.path.join(self.root_dir,
  34. self.landmarks_frame.iloc[idx, 0])
  35. image = io.imread(img_name)
  36. label = self.landmarks_frame.iloc[idx, 1]
  37. label = label.astype('float')
  38. sample = {'image': image, 'label': label}
  39.  
  40. if self.transform:
  41. sample = self.transform(sample)
  42.  
  43. return sample
  44.  
  45.  
  46. # 定义transform的操作
  47. class Rescale(object):
  48. """按照给定尺寸更改一个图像的尺寸
  49.  
  50. Args:
  51. output_size (tuple or int): 要求输出的尺寸. 如果是个元组类型, 输出
  52. 和output_size匹配. 如果是int类型,图片的短边和output_size匹配, 图片的
  53. 长宽比保持不变.
  54. """
  55.  
  56. def __init__(self, output_size):
  57. assert isinstance(output_size, (int, tuple))
  58. self.output_size = output_size
  59.  
  60. def __call__(self, sample):
  61. image, label = sample['image'], sample['label']
  62. h, w = image.shape[:2]
  63. if isinstance(self.output_size, int):
  64. if h > w:
  65. new_h, new_w = self.output_size * h / w, self.output_size
  66. else:
  67. new_h, new_w = self.output_size, self.output_size * w / h
  68. else:
  69. new_h, new_w = self.output_size
  70. new_h, new_w = int(new_h), int(new_w)
  71. img = transform.resize(image, (new_h, new_w))
  72.  
  73. return {'image': img, 'label': label}
  74.  
  75.  
  76.  
  77. class RandomCrop(object):
  78. """随机裁剪图片
  79.  
  80. Args:
  81. output_size (tuple or int): 期望输出的尺寸, 如果是int类型, 裁切成正方形.
  82. """
  83.  
  84. def __init__(self, output_size):
  85. assert isinstance(output_size, (int, tuple))
  86. if isinstance(output_size, int):
  87. self.output_size = (output_size, output_size)
  88. else:
  89. assert len(output_size) == 2
  90. self.output_size = output_size
  91.  
  92. def __call__(self, sample):
  93. image, label = sample['image'], sample['label']
  94.  
  95. h, w = image.shape[:2]
  96. new_h, new_w = self.output_size
  97.  
  98. top = np.random.randint(0, h - new_h)
  99. left = np.random.randint(0, w - new_w)
  100.  
  101. image = image[top: top + new_h,
  102. left: left + new_w]
  103.  
  104. return {'image': image, 'label': label}
  105.  
  106.  
  107.  
  108. class ToTensor(object):
  109. """将ndarrays的样本转化为Tensors的样本"""
  110.  
  111. def __call__(self, sample):
  112. image, label = sample['image'], sample['label']
  113.  
  114. # 矩阵变换,交换颜色通道, 因为
  115. # numpy图片: H x W x C
  116. # torch图片 : C X H X W
  117. image = image.transpose((2, 0, 1))
  118. return {'image': torch.from_numpy(image),
  119. 'label': torch.from_numpy(label)}
  120.  
  121.  
  122.  
  123.  
  124. data_transform=transforms.Compose([Rescale(256),
  125. RandomCrop(224),
  126. ToTensor(),
  127. transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
  128.  
  129. transformed_dataset = ClassifyDataset(csv_file='img_label.csv',
  130. root_dir='imgs/',
  131. transform=data_transform)
  132.  
  133. dataloader = DataLoader(transformed_dataset, batch_size=16,
  134. shuffle=True, num_workers=4)
  135.  
  136.  
  137. for num_batch, batch_item in enumerate(dataloader):
  138. images_batch = batch_item['image']
  139. labels_batch = batch_item['label']
  140. print(images_batch.size())
  141. print(labels_batch.size())
  142. batch_size = len(images_batch)
  143. im_size = images_batch.size(2)
Add Comment
Please, Sign In to add comment