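```python
import cv2
import torch

# transform, device, ResNet18, model_predict and classes are defined elsewhere in the project
img = transform(cv2.imread('0000.jpg')).unsqueeze(dim=0).to(device)  # single image -> (1, C, H, W) batch
model = ResNet18()
model.load_state_dict(torch.load('./saved_model/model_12000_9.pth', map_location=device))  # load weights in place
model.to(device)  # keep the weights on the same device as the input
print(type(img))
output = model_predict(img, model.eval(), device)  # eval() switches Dropout/BatchNorm to inference mode
print(classes[output])
```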
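The post does not show `transform` or `model_predict`, so here is a minimal sketch of what such helpers might look like. Everything in it is an assumption for illustration: `to_batch` is a hypothetical stand-in for `transform` plus `unsqueeze` (whether the original `transform` converts OpenCV's BGR order to RGB is not shown), `model_predict` is a hypothetical re-implementation, and the toy model, fake image and `classes` list stand in for `ResNet18`, `'0000.jpg'` and the real labels.

```python
import numpy as np
import torch
import torch.nn as nn

def to_batch(image_bgr: np.ndarray) -> torch.Tensor:
    # Hypothetical: cv2.imread returns an (H, W, C) uint8 array in BGR order
    rgb = image_bgr[:, :, ::-1].copy()             # BGR -> RGB; copy() removes the negative stride
    chw = torch.from_numpy(rgb).permute(2, 0, 1)   # (H, W, C) -> (C, H, W)
    return (chw.float() / 255.0).unsqueeze(dim=0)  # scale to [0, 1], add batch dim -> (1, C, H, W)

def model_predict(img: torch.Tensor, model: nn.Module, device: torch.device) -> int:
    # Hypothetical helper: returns the index of the highest-scoring class
    model = model.to(device)
    with torch.no_grad():                          # inference only: no gradient tracking
        logits = model(img.to(device))             # shape (1, num_classes)
    return int(logits.argmax(dim=1).item())

# Usage with stand-ins for ResNet18 and '0000.jpg'
classes = ['cat', 'dog', 'bird']                   # hypothetical label list
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, len(classes)))
img = to_batch(np.zeros((8, 8, 3), dtype=np.uint8))
print(img.shape)                                   # torch.Size([1, 3, 8, 8])
output = model_predict(img, toy.eval(), device)
print(classes[output])
```

Wrapping the forward pass in `torch.no_grad()` is a common choice for inference because no gradients are needed, which saves memory.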
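- The model input must be a batch of shape (batch size, channel, h, w); this is why `unsqueeze(dim=0)` adds a batch dimension to the single image.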
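- Wrong (X):

```python
model = model.load_state_dict(torch.load(PATH))
model = model.eval()
```

→ Error (a TypeError here). `load_state_dict()` copies the weights into the existing model in place and returns a named tuple of `missing_keys` and `unexpected_keys`, not the model itself, so after the first line `model` no longer refers to the network. (`model = model.eval()` on its own is harmless, because `eval()` returns the module.)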
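- Correct (O):

```python
import torch

model = Resnet()                         # 1. build the architecture first
model.load_state_dict(torch.load(PATH))  # 2. load the weights in place; don't reassign the return value
model.eval()                             # 3. switch to inference mode
```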
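The difference between the two patterns can be checked directly. The sketch below is a minimal illustration under stated assumptions: `build_model` is a hypothetical stand-in for `Resnet()`, and an in-memory `io.BytesIO` buffer plays the role of the checkpoint file at `PATH`.

```python
import io
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Hypothetical stand-in for Resnet(); any nn.Module behaves the same way
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 2))

# In-memory buffer in place of the checkpoint file at PATH
trained = build_model()
buf = io.BytesIO()
torch.save(trained.state_dict(), buf)

# Wrong pattern: the return value only reports key mismatches, it is not the model
buf.seek(0)
result = build_model().load_state_dict(torch.load(buf))
print(result.missing_keys, result.unexpected_keys)  # [] []
print(isinstance(result, nn.Module))                # False

# Correct pattern: build, load in place, then switch to eval mode
buf.seek(0)
model = build_model()
model.load_state_dict(torch.load(buf, map_location='cpu'))
model.eval()                                        # disables Dropout (BatchNorm would use running stats)

x = torch.randn(1, 4)
with torch.no_grad():
    same = torch.equal(model(x), model(x))
print(same)                                         # True: Dropout is inactive in eval mode
```

Without `model.eval()`, the Dropout layer would still be active and two forward passes on the same input would generally differ.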