Pytorch 加载自有数据集

Pytorch 加载自有数据集,第1张

Pytorch 加载自有数据集

抽象基类 data.Dataset

class Dataset(Generic[T_co]):
	r"""表示数据集的抽象类
	所有数据集都是该类的子类。子类必须重写 `__getitem__` 方法,
	实现通过 key 获取数据样本;子类也可以重写 `__len___` 方法,
	来获取数据集的尺寸。
	"""
1. 继承基类自定义数据集
# 自定义数据集类
class MyDataset(torch.utils.data.Dataset):
    # 
    def __init__(self, *args):
        super().__init__()
        # 初始化数据集包含的数据和标签
        pass
        
    def __getitem__(self, index):
        # 根据索引index从文件中读取一个数据
        # 对数据预处理
        # 返回数据即对应标签
        pass
    
    def __len__(self):
        # 返回数据集的大小
        return len()

示例:
本地图片保存在文件夹 “~/train_data” 中:
0.png
1.png
2.png
3.png
4.png
5.png

标签存放在对应的 txt 文件中:
0.png 0
1.png 0
2.png 2
3.png 1
4.png 3
5.png 2

创建自定义数据集类:

class MyData(torch.utils.data.Dataset):
    def __init__(self, data_path, label_path):
        self.data_path = data_path
        self.label_path = label_path
        self.imgs = []
    
        with open(labels_path, 'r') as fp:
            for line in fp:
                line = line.strip()
                sample = line.split(' ')
                self.imgs.append((sample[0], sample[1]))
                
    def __len__():
        return len(self.imgs)
    
    def __getitem__(self, index):
        picture, label = self.imgs[index]
        picture = transforms.ToTensor()(Image.open(self.data_path+'/picture'))
        return picture, label

2. 创建数据集
data_path = './train'
label_path = './train_label.txt'

train_data = MyData(data_path, label_path)

print(train_data[3])

输出:

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0039,
          0.0000, 0.0000, 0.0000, 0.0000, 0.1255, 0.2392, 0.2196, 0.2549,
          0.0549, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0039,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0078,
          0.0000, 0.0118, 0.2745, 0.6902, 0.9137, 0.8667, 0.7725, 0.7725,
          0.7882, 0.7255, 0.5137, 0.1725, 0.0000, 0.0000, 0.0039, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         ... .... 
         [0.0000, 0.0000, 0.0000, 0.0000, 0.1137, 0.0118, 0.0000, 0.0000,
          0.2431, 0.5882, 0.5255, 0.6039, 0.7059, 0.8863, 0.6980, 0.7725,
          0.7373, 0.9216, 0.7686, 0.6431, 0.6275, 0.0000, 0.0000, 0.0196,
          0.1020, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000,
          0.0000, 0.0784, 0.0863, 0.1451, 0.2039, 0.2627, 0.3176, 0.2353,
          0.2196, 0.2824, 0.1608, 0.1020, 0.0039, 0.0000, 0.0039, 0.0000,
          0.0000, 0.0000, 0.0039, 0.0000]]]), tensor(4))

L = torch.utils.data.DataLoader(train_data, batch_size=50)

for X, y in L:
    print(X.shape, y.shape)
    break

输出:

torch.Size([50, 1, 28, 28]) torch.Size([50])

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/725956.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-04-26
下一篇2022-04-26

发表评论

登录后才能评论

评论列表(0条)

    保存