【torchsummary报错】RuntimeError: Input type (torch.cuda.FloatTensor) and weight type

【torchsummary报错】RuntimeError: Input type (torch.cuda.FloatTensor) and weight type ,第1张

源代码:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(pretrained=None)
model.fc = nn.Linear(512, 10)
    
summary(model, input_size=[(3, 224, 224)], batch_size=256, device="cuda")

报错为:RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

解决方法:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(pretrained=None)
model.fc = nn.Linear(512, 10)

model = model.to(device)  # 加一行这个就Ok了

summary(model, input_size=[(3, 224, 224)], batch_size=256, device="cuda")

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/578232.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-04-11
下一篇2022-04-11

发表评论

登录后才能评论

评论列表(0条)

    保存