
import torch
################################################
#创建一个2*3的tensor
a = torch.zeros([2,3])
print(a)#tensor([[0., 0., 0.],
# [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])
################################################
#在第0个位置增加一个维度
a = a.unsqueeze(0)
print(a)#tensor([[[0., 0., 0.],
# [0., 0., 0.]]])
print(a.shape)#torch.Size([1, 2, 3])
################################################
#在第0个位置减少一个维度 前提是0处的维度大小是1
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
# [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])
################################################
#0处的维度不是1,所以不生效
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
# [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])
可以看出unsqueeze(dim)函数就是让tensor在dim处增加一个维度;
而squeeze(dim)函数就是让tensor在dim处减少一个维度;但前提是dim处的维度是1,否则squeeze(dim)函数不会生效。
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)