
有的网上下载下来的数据集,数据和标签是混在一起的,不像torchvision里下载好的那样数据和标签给咱们做好分类,所以需要我们重写DataLoader这个函数。
例子:以图像分割数据集为例, 使用如下数据集重写我们的dataloader:
数据集来源:
Xiaoyong Shen, Xin Tao, Hongyun Gao, Chao Zhou, Jiaya Jia. Deep Automatic Portrait Matting. European Conference on Computer Vision (ECCV),2016.
图一 训练集中label和picture分别一个文件夹
图二 标签的文件名称 #####_matte
图三 图片
1.导入包:
import os
from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
2.Mydata 类
必须要有的函数 len 和 getitem
''' 数据集处理 '''
class Mydata(Dataset):
def __init__(self, img_path, label_path):
self.img_path = img_path # 'dataset/training/picture'
self.label_path = label_path # 'dataset/training/label'
self.label_data = os.listdir(self.label_path) # 这一步是找到此文件夹下的1700张图片,找到的结果是:’00001_matte.png‘等
self.totensor = transforms.ToTensor() # 将数据改为tensor类型
self.resize1 = transforms.Resize((572, 572)) # 对数据进行裁剪
self.resize2 = transforms.Resize((388, 388))
def __len__(self):
return len(self.label_data) # 1700 返回数据的个数,有1700张
def __getitem__(self, item):
# 要保证label和picture的序号是一致的,所以用label的文件名找picture 对应起来
img_name = os.path.join(self.img_path, self.label_data[item]) # 'dataset/training/picture' +‘00001_matte.png’
img_name = os.path.split(img_name) # 'dataset/training/picture/00001_matte.png' -> ('dataset/training/picture','00001_matte.png')
img_name = img_name[-1] #'00001_matte.png'
img_name = img_name.split('_')
img_name = img_name[0] + '.png' # '00001' + ‘.png’ 就是数据的文件名
img_data = os.path.join(self.img_path, img_name) # 'dataset/training/picture/00001.png'
label_data =os.path.join(self.label_path, self.label_data[item])
''' 找到对应文件以后,就对label和picture做想做的处理'''
img = Image.open(img_data)
label = Image.open(label_data).convert('L')
img = self.totensor(img)
img = self.resize1(img)
label = self.totensor(label)
label = self.resize2(label)
label = torch.cat((label, label), dim=0)
return img, label # 返回
'''进行dataloader加载,加载后trainloader里就是处理好的(img,label)'''
train_dataset = Mydata('dataset/training/picture', 'dataset/training/label')
trainloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)