pytorch的squeeze()和unsqueeze

pytorch的squeeze()和unsqueeze,第1张

一,官方文档  torch.squeeze — PyTorch 1.11.0 documentationhttps://pytorch.org/docs/stable/generated/torch.squeeze.html?highlight=squeeze#torch.squeezetorch.unsqueeze — PyTorch 1.11.0 documentationhttps://pytorch.org/docs/stable/generated/torch.unsqueeze.html#torch-unsqueeze二,代码理解
import torch

################################################
#创建一个2*3的tensor
a = torch.zeros([2,3])
print(a)#tensor([[0., 0., 0.],
        #       [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])

################################################
#在第0个位置增加一个维度
a = a.unsqueeze(0)
print(a)#tensor([[[0., 0., 0.],
        #       [0., 0., 0.]]])
print(a.shape)#torch.Size([1, 2, 3])

################################################
#在第0个位置减少一个维度 前提是0处的维度大小是1
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
        #       [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])

################################################
#0处的维度不是1,所以不生效
a = a.squeeze(0)
print(a)#tensor([[0., 0., 0.],
        #       [0., 0., 0.]])
print(a.shape)#torch.Size([2, 3])

可以看出unsqueeze(dim)函数就是让tensor在dim处增加一个维度;

而squeeze(dim)函数就是让tensor在dim处减少一个维度;但前提是dim处的维度是1,否则squeeze(dim)函数不会生效。

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/942667.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-18
下一篇2022-05-18

发表评论

登录后才能评论

评论列表(0条)

    保存