Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Read images and their labels and feed them in as input data.
- # Intended for a CSV index with two columns: image file name + label.
- from __future__ import print_function, division
- import os
- import torch
- import pandas as pd
- from skimage import io, transform
- import numpy as np
- import matplotlib.pyplot as plt
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, utils
class ClassifyDataset(Dataset):
    """Image-classification dataset indexed by a CSV file.

    The CSV is loaded with ``pandas.read_csv``. For every row, column 0 is
    read as the image file name (joined onto ``root_dir``) and column 1 as
    the label.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file listing image names and
                labels.
            root_dir (string): Directory that contains the image files.
            transform (callable, optional): Callable applied to each sample
                dict before it is returned.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Row position in the CSV index. A tensor index is converted
                to plain Python values first.

        Returns:
            dict: ``{'image': ndarray, 'label': 0-d float ndarray}``, passed
            through ``self.transform`` when one was supplied.
        """
        # Samplers may hand over tensor indices; pandas needs plain ints.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)

        # Keep the label as an ndarray (not a NumPy scalar): downstream code
        # such as torch.from_numpy only accepts ndarrays.
        label = np.asarray(self.landmarks_frame.iloc[idx, 1], dtype='float')

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
- # Transform operations applied to each sample.
class Rescale(object):
    """Resize the image of a sample to a requested size.

    Args:
        output_size (tuple or int): Requested output size. A tuple is used
            directly as ``(height, width)``. An int is matched to the
            shorter image edge and the other edge is scaled by the same
            ratio, so the aspect ratio is kept.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_shape(self, height, width):
        """Return the integer ``(rows, cols)`` to resize to."""
        size = self.output_size
        if not isinstance(size, int):
            rows, cols = size
        elif height > width:
            # Width is the shorter edge: pin it, scale the height.
            rows, cols = size * height / width, size
        else:
            # Height is the shorter (or equal) edge: pin it, scale the width.
            rows, cols = size, size * width / height
        return int(rows), int(cols)

    def __call__(self, sample):
        picture = sample['image']
        target = sample['label']
        height, width = picture.shape[:2]
        resized = transform.resize(picture, self._target_shape(height, width))
        return {'image': resized, 'label': target}
class RandomCrop(object):
    """Cut a randomly positioned window out of the image of a sample.

    Args:
        output_size (tuple or int): Requested crop size as
            ``(height, width)``. An int requests a square crop.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample whose image is a random crop of the input.

        Args:
            sample (dict): Mapping with an ``'image'`` array whose first two
                axes are height and width, and a ``'label'`` entry that is
                passed through unchanged.

        Raises:
            ValueError: If the image is smaller than the requested crop.
        """
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size ({}, {}) exceeds image size ({}, {})'.format(
                    new_h, new_w, h, w))

        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # offset reachable and lets an image of exactly the crop size pass.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        return {'image': image, 'label': label}
class ToTensor(object):
    """Convert the ndarrays in a sample to Tensors."""

    def __call__(self, sample):
        """Return a new sample dict holding tensors.

        Args:
            sample (dict): Mapping with an ``'image'`` array laid out as
                H x W x C and a ``'label'`` that is an ndarray or a NumPy
                scalar.

        Returns:
            dict: ``'image'`` as a C x H x W tensor and ``'label'`` as a
            tensor.
        """
        image, label = sample['image'], sample['label']

        # Swap the colour axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))

        # as_tensor accepts NumPy scalars as well as ndarrays, whereas
        # from_numpy raises TypeError for a scalar label such as np.float64.
        return {'image': torch.from_numpy(image),
                'label': torch.as_tensor(label)}
class NormalizeSample(object):
    """Normalize the image tensor of a sample dict; the label is untouched.

    ``transforms.Normalize`` operates on a tensor, while every other
    transform in this pipeline passes ``{'image', 'label'}`` dicts along,
    so it cannot be placed in the ``Compose`` directly.

    Args:
        mean (sequence): Per-channel means handed to ``transforms.Normalize``.
        std (sequence): Per-channel standard deviations handed to
            ``transforms.Normalize``.
    """

    def __init__(self, mean, std):
        self.normalize = transforms.Normalize(mean=mean, std=std)

    def __call__(self, sample):
        return {'image': self.normalize(sample['image']),
                'label': sample['label']}


# Full preprocessing pipeline: resize, random crop, tensor conversion,
# then per-channel normalization of the image.
data_transform = transforms.Compose([
    Rescale(256),
    RandomCrop(224),
    ToTensor(),
    NormalizeSample(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transformed_dataset = ClassifyDataset(csv_file='img_label.csv',
                                      root_dir='imgs/',
                                      transform=data_transform)

dataloader = DataLoader(transformed_dataset, batch_size=16,
                        shuffle=True, num_workers=4)

# Worker processes may re-import this module (spawn start method), so the
# iteration itself has to sit behind the main guard.
if __name__ == '__main__':
    for num_batch, batch_item in enumerate(dataloader):
        images_batch = batch_item['image']
        labels_batch = batch_item['label']
        print(images_batch.size())
        print(labels_batch.size())
        batch_size = len(images_batch)
        im_size = images_batch.size(2)
Add Comment
Please, Sign In to add comment