python-TensorDataset上的PyTorch转换

python-TensorDataset上的PyTorch转换,第1张

概述我正在使用TensorDataset从numpy数组创建数据集.# convert numpy arrays to pytorch tensors X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train]) y_train = torch.stack([torch.from_

我正在使用TensorDataset从numpy数组创建数据集.

# convert numpy arrays to pytorch tensorsX_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])# reshape into [C,H,W]X_train = X_train.reshape((-1,1,28,28)).float()# create dataset and DataLoaderstrain_dataset = torch.utils.data.TensorDataset(X_train,y_train)train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=64)

如何将数据扩充(transforms)应用于TensorDataset?

例如,使用ImageFolder,我可以将转换指定为torchvision.datasets.ImageFolder(root,transform = …)的参数之一.

根据PyTorch团队成员之一的this reply,默认情况下不支持.有其他替代方法吗?

随意询问是否需要更多代码来解释问题.

最佳答案默认情况下,TensorDataset不支持变换.但是我们可以创建我们的自定义类来添加该选项.但是,正如我已经提到的,大多数转换都是为PIL.Image开发的.但是无论如何,这里是带有非常虚拟转换的非常简单的MNIST示例.带有MNIST here的csv文件.

码:

import numpy as npimport torchfrom torch.utils.data import Dataset,TensorDatasetimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as plt# import mnist dataset from cvs file and convert it to torch tensorwith open('mnist_train.csv','r') as f:    mnist_train = f.readlines()# ImagesX_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])X_train = X_train.reshape((-1,28))X_train = torch.tensor(X_train)# Labelsy_train = np.array([int(i[0]) for i in mnist_train])y_train = y_train.reshape(y_train.shape[0],1)y_train = torch.tensor(y_train)del mnist_trainclass CustomTensorDataset(Dataset):    """TensorDataset with support of transforms.    """    def __init__(self,tensors,transform=None):        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)        self.tensors = tensors        self.transform = transform    def __getitem__(self,index):        x = self.tensors[0][index]        if self.transform:            x = self.transform(x)        y = self.tensors[1][index]        return x,y    def __len__(self):        return self.tensors[0].size(0)def imshow(img,Title=''):    """Plot the image batch.    """    plt.figure(figsize=(10,10))    plt.Title(Title)    plt.imshow(np.transpose( img.numpy(),(1,2,0)),cmap='gray')    plt.show()# Dataset w/o any tranformationstrain_dataset_normal = CustomTensorDataset(tensors=(X_train,y_train),transform=None)train_loader = torch.utils.data.DataLoader(train_dataset_normal,batch_size=16)# iteratefor i,data in enumerate(train_loader):    x,y = data      imshow(torchvision.utils.make_grID(x,4),Title='normal')    break  # we need just one batch# Let's add some transforms# Dataset with flipPing tranformationsdef vflip(tensor):    """Flips tensor vertically.    """    tensor = tensor.flip(1)    return tensordef hflip(tensor):    """Flips tensor horizontally.    """    tensor = tensor.flip(2)    return tensortrain_dataset_vf = CustomTensorDataset(tensors=(X_train,transform=vflip)train_loader = torch.utils.data.DataLoader(train_dataset_vf,batch_size=16)result = []for i,Title='Vertical flip')    breaktrain_dataset_hf = CustomTensorDataset(tensors=(X_train,transform=hflip)train_loader = torch.utils.data.DataLoader(train_dataset_hf,Title='Horizontal flip')    break

输出:



总结

以上是内存溢出为你收集整理的python-TensorDataset上的PyTorch转换 全部内容,希望文章能够帮你解决python-TensorDataset上的PyTorch转换 所遇到的程序开发问题。

如果觉得内存溢出网站内容还不错,欢迎将内存溢出网站推荐给程序员好友。

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/1199640.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-06-04
下一篇2022-06-04

发表评论

登录后才能评论

评论列表(0条)

    保存