上上上周总结------重写Dataloader,自己的mydata代码

上上上周总结------重写Dataloader,自己的mydata代码,第1张

问题描述:

有的网上下载下来的数据集,数据和标签是混在一起的,不像torchvision里下载好的那样数据和标签给咱们做好分类,所以需要我们重写DataLoader这个函数。
例子:以图像分割数据集为例, 使用如下数据集重写我们的dataloader:

数据集来源:
Xiaoyong Shen, Xin Tao, Hongyun Gao, Chao Zhou, Jiaya Jia. Deep Automatic Portrait Matting. European Conference on Computer Vision (ECCV),2016.


图一 训练集中label和picture分别一个文件夹


图二 标签的文件名称 #####_matte


图三 图片

1.导入包:

import os
from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

2.Mydata 类
必须要有的函数 len 和 getitem

''' 数据集处理 '''

class Mydata(Dataset):
    def __init__(self, img_path, label_path):
        self.img_path = img_path  #  'dataset/training/picture'
        self.label_path = label_path  # 'dataset/training/label'
        self.label_data = os.listdir(self.label_path)  # 这一步是找到此文件夹下的1700张图片,找到的结果是:’00001_matte.png‘等
        self.totensor = transforms.ToTensor()   # 将数据改为tensor类型
        self.resize1 = transforms.Resize((572, 572)) # 对数据进行裁剪
        self.resize2 = transforms.Resize((388, 388))
    def __len__(self):
        return len(self.label_data)  # 1700 返回数据的个数,有1700张
    def __getitem__(self, item):
        # 要保证label和picture的序号是一致的,所以用label的文件名找picture 对应起来
        img_name = os.path.join(self.img_path, self.label_data[item]) # 'dataset/training/picture' +‘00001_matte.png’

        img_name = os.path.split(img_name) # 'dataset/training/picture/00001_matte.png' -> ('dataset/training/picture','00001_matte.png')
        img_name = img_name[-1] #'00001_matte.png'
        img_name = img_name.split('_')
        img_name = img_name[0] + '.png'   #  '00001' + ‘.png’ 就是数据的文件名
        img_data = os.path.join(self.img_path, img_name) # 'dataset/training/picture/00001.png'
        label_data =os.path.join(self.label_path, self.label_data[item])

        ''' 找到对应文件以后,就对label和picture做想做的处理'''
        img = Image.open(img_data)
        label = Image.open(label_data).convert('L')


        img = self.totensor(img)
        img = self.resize1(img)
        label = self.totensor(label)
        label = self.resize2(label)
        label = torch.cat((label, label), dim=0)


        return img, label # 返回 

'''进行dataloader加载,加载后trainloader里就是处理好的(img,label)'''
train_dataset = Mydata('dataset/training/picture', 'dataset/training/label')
trainloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/915751.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-16
下一篇2022-05-16

发表评论

登录后才能评论

评论列表(0条)

    保存