Pytorch学习(1)

Pytorch学习(1),第1张

关于Dataset类的重写

目的是为了构建我们所需要的数据集的相关方法

首先导入包

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image 
import os

Dataset这个类在torch.utils.data下
PIL作为打开数据集中图片的相关库来使用,os库则是作为相关路径的提取而使用的


在Dataset这个类中,我们一共需要重写三个方法,分别是

__init__
__getitem__
__len__

__init__作为初始化函数将会为我们类提供默认的数据集路径,默认的数据集路径包含有rootdir和labeldir

    def __init__(self,rootdir,labeldir):
        self.rootdir = rootdir  #数据集的母路径
        self.labeldir = labeldir    #选用的数据集名称(label)
        self.path = os.path.join(rootdir,labeldir)  #所选用的数据集的路径
        self.img_path = os.listdir(self.path)   #数据集图片列表

其中的os.path.join(a,b)函数是可以将两端目录路径进行拼接。os.listdir()方法则是可以将某个路径下的相关文件以列表的形式进行排列,这样我们就可以通过列表的方法来访问数据集中的文件(图片)。

__getitem__方法则是通过参数item能够返回所选定的item编号图片本身和label

    def __getitem__(self, item):
        img_name = self.img_path[item]  #所选图片名称
        img_item_path = os.path.join(self.rootdir,self.labeldir,img_name)   #所选图片所在路径
        img = Image.open(img_item_path) #打开选中item编号的图片
        label = self.labeldir   #所选item编号图片的所在数据集名称(label)
        return img, label

其中Image.open()可以返回一个PIL.Image格式类型的图片

最后的len方法则是获取数据集的长度

  def __len__(self):
        return len(self.img_path)

示例:

root_dir = r'D:\pycharm_code\main_demo\dataset\hymenoptera_data\train'
ants_label_dir = r'ants'
ants_dataset = mtdata(root_dir,ants_label_dir)
bees_label_dir = r'bees'
bees_dataset = mtdata(root_dir,bees_label_dir)

注意:访问某一张图片时可以使用:

ants_dataset[x] #x是希望访问图片的标号

多个数据集之间可以通过+来进行合并

train_dataset = ants_dataset + bees_dataset


整体代码:

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os

class mtdata(Dataset):
    def __init__(self,rootdir,labeldir):
        self.rootdir = rootdir  #数据集的母路径
        self.labeldir = labeldir    #选用的数据集名称(label)
        self.path = os.path.join(rootdir,labeldir)  #所选用的数据集的路径
        self.img_path = os.listdir(self.path)   #数据集图片列表
    def __getitem__(self, item):
        img_name = self.img_path[item]  #所选图片名称
        img_item_path = os.path.join(self.rootdir,self.labeldir,img_name)   #所选图片所在路径
        img = Image.open(img_item_path) #打开选中item编号的图片
        label = self.labeldir   #所选item编号图片的所在数据集名称(label)
        return img, label
    def __len__(self):
        return len(self.img_path)


root_dir = r'D:\pycharm_code\main_demo\dataset\hymenoptera_data\train'
ants_label_dir = r'ants'
ants_dataset = mtdata(root_dir,ants_label_dir)
bees_label_dir = r'bees'
bees_dataset = mtdata(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/916506.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-16
下一篇2022-05-16

发表评论

登录后才能评论

评论列表(0条)

    保存