
抽象基类 data.Dataset。
class Dataset(Generic[T_co]):
    r"""An abstract class representing a dataset.

    All datasets are subclasses of this class. Subclasses must override
    `__getitem__`, which fetches a data sample for a given key; subclasses
    may also override `__len__`, which returns the size of the dataset.
    """
1. 继承基类自定义数据集
# Custom dataset class (template to copy and fill in)
class MyDataset(torch.utils.data.Dataset):
    """Skeleton of a user-defined dataset.

    Fill in ``__init__`` so that ``self.samples`` holds one entry per data
    item, and implement ``__getitem__`` to load and return a single item.
    """

    def __init__(self, *args):
        """Set up the data and labels that make up the dataset.

        Args:
            *args: Whatever the concrete dataset needs (paths, options, ...).
                The template itself does not use them.
        """
        super().__init__()
        # Placeholder index of the samples; populate it with the real
        # data/label records of the dataset.
        self.samples = []

    def __getitem__(self, index):
        """Return the sample (data and its label) stored at ``index``.

        A concrete implementation is expected to:
          1. read one data item from disk according to ``index``;
          2. preprocess the data;
          3. return the data together with its label.

        Raises:
            NotImplementedError: Always, until the template is filled in.
        """
        # Fail loudly instead of silently returning None from an unfinished
        # template.
        raise NotImplementedError("MyDataset.__getitem__ must be implemented")

    def __len__(self):
        """Return the number of samples in the dataset."""
        # The original called ``len()`` without an argument, which raises
        # TypeError; report the size of the sample index instead.
        return len(self.samples)
示例:
本地图片保存在文件夹 “./train” 中(与下文代码中的 data_path 一致):
0.png
1.png
2.png
3.png
4.png
5.png
…
标签存放在对应的 txt 文件(如 ./train_label.txt)中,每行格式为“文件名 标签”:
0.png 0
1.png 0
2.png 2
3.png 1
4.png 3
5.png 2
…
创建自定义数据集类:
class MyData(torch.utils.data.Dataset):
    """Image dataset backed by a folder of pictures and a text label file.

    Every non-blank line of the label file holds a picture file name and an
    integer class label separated by whitespace, e.g. ``3.png 1``.

    NOTE(review): ``__getitem__`` relies on ``Image`` (presumably
    ``PIL.Image``) and ``transforms`` (presumably ``torchvision.transforms``)
    being imported by the surrounding file -- confirm.
    """

    def __init__(self, data_path, label_path):
        """Read the label file and index every (file name, label) pair.

        Args:
            data_path: Directory that contains the picture files.
            label_path: Text file with one ``<file name> <label>`` per line.

        Raises:
            ValueError: If a non-blank line lacks a label or the label is
                not an integer.
        """
        super().__init__()
        self.data_path = data_path
        self.label_path = label_path
        self.imgs = []
        # The original opened the undefined name ``labels_path`` here.
        with open(label_path, 'r') as fp:
            for line_no, line in enumerate(fp, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected '<file> <label>'"
                    )
                # Store the label as an int so batches collate into tensors.
                self.imgs.append((fields[0], int(fields[1])))

    def __len__(self):
        """Return the number of indexed pictures."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load the picture at ``index`` and return ``(tensor, label)``.

        Args:
            index: Position of the sample in ``self.imgs``.

        Returns:
            A tuple of the picture converted by ``transforms.ToTensor()`` and
            the label as a 0-dim ``torch.Tensor``.
        """
        picture, label = self.imgs[index]
        # Open the file named in the label list (the original always opened
        # the literal path '<data_path>/picture') and close it after reading.
        with Image.open(self.data_path + '/' + picture) as img:
            picture = transforms.ToTensor()(img)
        return picture, torch.tensor(label)
2. 创建数据集
# Locations of the picture folder and of the label file.
data_path, label_path = './train', './train_label.txt'

# Build the dataset and display the sample stored at index 3.
train_data = MyData(data_path=data_path, label_path=label_path)
print(train_data[3])
输出:
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0039,
0.0000, 0.0000, 0.0000, 0.0000, 0.1255, 0.2392, 0.2196, 0.2549,
0.0549, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0039,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0078,
0.0000, 0.0118, 0.2745, 0.6902, 0.9137, 0.8667, 0.7725, 0.7725,
0.7882, 0.7255, 0.5137, 0.1725, 0.0000, 0.0000, 0.0039, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
... ....
[0.0000, 0.0000, 0.0000, 0.0000, 0.1137, 0.0118, 0.0000, 0.0000,
0.2431, 0.5882, 0.5255, 0.6039, 0.7059, 0.8863, 0.6980, 0.7725,
0.7373, 0.9216, 0.7686, 0.6431, 0.6275, 0.0000, 0.0000, 0.0196,
0.1020, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000,
0.0000, 0.0784, 0.0863, 0.1451, 0.2039, 0.2627, 0.3176, 0.2353,
0.2196, 0.2824, 0.1608, 0.1020, 0.0039, 0.0000, 0.0039, 0.0000,
0.0000, 0.0000, 0.0039, 0.0000]]]), tensor(4))
# Wrap the dataset in a loader that yields mini-batches of 50 samples.
L = torch.utils.data.DataLoader(dataset=train_data, batch_size=50)
for X, y in L:
    # Only the first batch is needed to inspect the batch shapes.
    shapes = (X.shape, y.shape)
    print(*shapes)
    break
输出:
torch.Size([50, 1, 28, 28]) torch.Size([50])
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)