
关于Dataset类的重写
目的是为了构建我们所需要的数据集的相关方法
首先导入包
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os
Dataset这个类在torch.utils.data下
PIL作为打开数据集中图片的相关库来使用,os库则是作为相关路径的提取而使用的
在Dataset这个类中,我们一共需要重写三个方法,分别是
__init__
__getitem__
__len__
__init__作为初始化函数将会为我们类提供默认的数据集路径,默认的数据集路径包含有rootdir和labeldir
def __init__(self,rootdir,labeldir):
self.rootdir = rootdir #数据集的母路径
self.labeldir = labeldir #选用的数据集名称(label)
self.path = os.path.join(rootdir,labeldir) #所选用的数据集的路径
self.img_path = os.listdir(self.path) #数据集图片列表
其中的os.path.join(a,b)函数是可以将两端目录路径进行拼接。os.listdir()方法则是可以将某个路径下的相关文件以列表的形式进行排列,这样我们就可以通过列表的方法来访问数据集中的文件(图片)。
__getitem__方法则是通过参数item能够返回所选定的item编号图片本身和label
def __getitem__(self, item):
img_name = self.img_path[item] #所选图片名称
img_item_path = os.path.join(self.rootdir,self.labeldir,img_name) #所选图片所在路径
img = Image.open(img_item_path) #打开选中item编号的图片
label = self.labeldir #所选item编号图片的所在数据集名称(label)
return img, label
其中Image.open()可以返回一个PIL.Image格式类型的图片
最后的len方法则是获取数据集的长度
def __len__(self):
return len(self.img_path)
示例:
root_dir = r'D:\pycharm_code\main_demo\dataset\hymenoptera_data\train'
ants_label_dir = r'ants'
ants_dataset = mtdata(root_dir,ants_label_dir)
bees_label_dir = r'bees'
bees_dataset = mtdata(root_dir,bees_label_dir)
注意:访问某一张图片时可以使用:
ants_dataset[x] #x是希望访问图片的标号
多个数据集之间可以通过+来进行合并
train_dataset = ants_dataset + bees_dataset
整体代码:
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os
class mtdata(Dataset):
def __init__(self,rootdir,labeldir):
self.rootdir = rootdir #数据集的母路径
self.labeldir = labeldir #选用的数据集名称(label)
self.path = os.path.join(rootdir,labeldir) #所选用的数据集的路径
self.img_path = os.listdir(self.path) #数据集图片列表
def __getitem__(self, item):
img_name = self.img_path[item] #所选图片名称
img_item_path = os.path.join(self.rootdir,self.labeldir,img_name) #所选图片所在路径
img = Image.open(img_item_path) #打开选中item编号的图片
label = self.labeldir #所选item编号图片的所在数据集名称(label)
return img, label
def __len__(self):
return len(self.img_path)
root_dir = r'D:\pycharm_code\main_demo\dataset\hymenoptera_data\train'
ants_label_dir = r'ants'
ants_dataset = mtdata(root_dir,ants_label_dir)
bees_label_dir = r'bees'
bees_dataset = mtdata(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)