
I am using TensorDataset to create a dataset from numpy arrays.
    import numpy as np
    import torch

    # X_train and y_train are existing collections of numpy arrays
    # convert numpy arrays to pytorch tensors
    X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])
    y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])

    # reshape into [C,H,W]
    X_train = X_train.reshape((-1, 1, 28, 28)).float()

    # create dataset and dataloaders
    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)

How can I apply data augmentation (transforms) to a TensorDataset?
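For example, with ImageFolder I can specify the transforms as one of the arguments: torchvision.datasets.ImageFolder(root, transform=…).
According to this reply by one of the PyTorch team members, this is not supported by default. Is there an alternative?
Feel free to ask if more code is needed to explain the problem.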
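Best answer:
By default, TensorDataset does not support transforms, but we can create a custom class that adds this option. However, as I already mentioned, most transforms are developed for PIL.Image. Anyway, here is a very simple MNIST example with very dummy transforms. The csv file with MNIST is available here. Code: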
    import numpy as np
    import torch
    from torch.utils.data import Dataset, TensorDataset
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt

    # import mnist dataset from csv file and convert it to torch tensor
    with open('mnist_train.csv', 'r') as f:
        mnist_train = f.readlines()

    # Images: drop the label column, then reshape to [N, C, H, W]
    X_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])
    X_train = X_train.reshape((-1, 1, 28, 28))
    X_train = torch.tensor(X_train)

    # Labels: the first character of each row is the digit
    y_train = np.array([int(i[0]) for i in mnist_train])
    y_train = y_train.reshape(y_train.shape[0], 1)
    y_train = torch.tensor(y_train)

    del mnist_train


    class CustomTensorDataset(Dataset):
        """TensorDataset with support of transforms."""

        def __init__(self, tensors, transform=None):
            assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
            self.tensors = tensors
            self.transform = transform

        def __getitem__(self, index):
            x = self.tensors[0][index]
            if self.transform:
                x = self.transform(x)
            y = self.tensors[1][index]
            return x, y

        def __len__(self):
            return self.tensors[0].size(0)


    def imshow(img, title=''):
        """Plot the image batch."""
        plt.figure(figsize=(10, 10))
        plt.title(title)
        plt.imshow(np.transpose(img.numpy(), (1, 2, 0)), cmap='gray')
        plt.show()


    # Dataset w/o any transformations
    train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
    train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)

    # iterate
    for i, data in enumerate(train_loader):
        x, y = data
        imshow(torchvision.utils.make_grid(x, 4), title='normal')
        break  # we need just one batch

    # Let's add some transforms

    # Dataset with flipping transformations
    def vflip(tensor):
        """Flips tensor vertically (each sample is [C, H, W], so dim 1 is height)."""
        tensor = tensor.flip(1)
        return tensor


    def hflip(tensor):
        """Flips tensor horizontally (dim 2 is width)."""
        tensor = tensor.flip(2)
        return tensor


    train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
    train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)

    for i, data in enumerate(train_loader):
        x, y = data
        imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
        break

    train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
    train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)

    for i, data in enumerate(train_loader):
        x, y = data
        imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
        break

Output: three image grids titled 'normal', 'Vertical flip' and 'Horizontal flip' (the images are not reproduced here).