[PyTorch] basic
frieden1946
2022. 11. 11. 11:45
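import cv2
import torch

# transform, device, ResNet18, model_predict and classes are defined elsewhere in the project
# cv2.imread returns an H x W x C BGR uint8 array (None if the file is missing)
img = transform(cv2.imread('0000.jpg')).unsqueeze(dim=0).to(device)  # add batch dim -> (1, C, H, W)
model = ResNet18().to(device)
model.load_state_dict(torch.load('./saved_model/model_12000_9.pth', map_location=device))
print(type(img))  # <class 'torch.Tensor'>
output = model_predict(img, model.eval(), device)
print(classes[output])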
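`transform` and `model_predict` are not shown in this post. The sketch below is one possible version, assuming `transform` converts the BGR image to RGB and scales it to a float C x H x W tensor in [0, 1] (like torchvision's `ToTensor`), and that `model_predict` runs a no-grad forward pass and returns the argmax class index. `preprocess`, this `model_predict`, the tiny stand-in model and the `classes` list are all hypothetical, and a random array replaces `cv2.imread('0000.jpg')` so it runs without the image file.

```python
import numpy as np
import torch
import torch.nn as nn


def preprocess(bgr: np.ndarray) -> torch.Tensor:
    """Hypothetical stand-in for `transform`: H x W x C uint8 BGR -> C x H x W float in [0, 1]."""
    rgb = bgr[:, :, ::-1].copy()                      # reverse channel axis BGR -> RGB; copy makes it contiguous
    tensor = torch.from_numpy(rgb).permute(2, 0, 1)   # HWC -> CHW
    return tensor.float() / 255.0


def model_predict(img: torch.Tensor, model: nn.Module, device: torch.device) -> int:
    """Hypothetical helper: forward one batch without gradients, return the top class index."""
    model.to(device)
    with torch.no_grad():
        logits = model(img.to(device))                # shape (batch, num_classes)
    return int(logits.argmax(dim=1)[0])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fake_bgr = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # stands in for cv2.imread('0000.jpg')
img = preprocess(fake_bgr).unsqueeze(dim=0).to(device)  # add batch dim -> (1, 3, 32, 32)
print(img.shape)  # torch.Size([1, 3, 32, 32])

# Tiny stand-in classifier (hypothetical), not the post's ResNet18
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
classes = [f"class_{i}" for i in range(10)]  # hypothetical label list
output = model_predict(img, model.eval(), device)
print(classes[output])
```

Without `unsqueeze(dim=0)` the tensor would be (3, 32, 32), and `nn.Flatten()` (which keeps dim 0 as the batch) would treat the 3 channels as 3 separate samples.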
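- img input -> (batch size, channel, h, w): the model expects a 4-D batch, so `unsqueeze(dim=0)` adds the batch dimension to a single (channel, h, w) image.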
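- X (wrong)
model = model.load_state_dict(torch.load(PATH))
model = model.eval()
-> Error: load_state_dict() does not return the model. It returns a named tuple of (missing_keys, unexpected_keys), so `model` no longer refers to the network and `model.eval()` fails (AttributeError).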
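- O (correct)
model = Resnet()
model.load_state_dict(torch.load(PATH))  # loads the weights into `model` in place
model.eval()  # switches to inference mode in place (dropout off, BatchNorm uses running stats)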
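The difference between the X and O versions can be checked with a small sketch. It assumes a tiny `nn.Linear(4, 2)` instead of ResNet and an in-memory `io.BytesIO` buffer instead of a `.pth` file at `PATH`; both are stand-ins for illustration only. `load_state_dict()` copies the weights into the module in place and returns a result object, while `eval()` also works in place but returns the module itself, so `model = model.eval()` on its own is harmless.

```python
import io

import torch
import torch.nn as nn

# Save a state_dict to an in-memory buffer (stands in for torch.save(model.state_dict(), PATH))
trained = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)


def load_state():
    buffer.seek(0)  # rewind so the buffer can be read more than once
    return torch.load(buffer)


# X: the return value of load_state_dict() is a key report, not the model
model = nn.Linear(4, 2)
result = model.load_state_dict(load_state())
print(result.missing_keys, result.unexpected_keys)  # [] []
try:
    result.eval()  # the key report has no eval() method
except AttributeError as err:
    print("AttributeError:", err)

# O: keep the module reference; both calls modify `model` in place
model = nn.Linear(4, 2)
model.load_state_dict(load_state())
model.eval()
print(model.training)  # False
print(torch.equal(model.weight, trained.weight))  # True
```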