

[PyTorch] Basics

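import cv2
import torch

# transform, device, ResNet18, model_predict and classes are defined elsewhere in the project
img = transform(cv2.imread('0000.jpg')).unsqueeze(dim=0).to(device)  # cv2 reads BGR (H, W, C); add batch dim -> (1, C, H, W)
model = ResNet18()
model.load_state_dict(torch.load('./saved_model/model_12000_9.pth', map_location=device))
model.to(device)  # the model and the input must be on the same device
print(type(img))
output = model_predict(img, model.eval(), device)
print(classes[output])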
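The post does not show `model_predict`. The sketch below is a hypothetical stand-in, assuming it runs a forward pass without gradients and returns the index of the highest-scoring class (which is how `classes[output]` is used above). The toy classifier and the `classes` labels are made up for illustration.

```python
import torch
import torch.nn as nn

def model_predict(img: torch.Tensor, model: nn.Module, device: torch.device) -> int:
    """Hypothetical helper: returns the predicted class index for a (1, C, H, W) batch."""
    model = model.to(device)
    with torch.no_grad():  # inference only, so skip gradient tracking
        logits = model(img.to(device))  # shape (1, num_classes)
    return int(logits.argmax(dim=1).item())

# usage with a tiny dummy classifier (hypothetical labels and model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classes = ["cat", "dog"]
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, len(classes)))
img = torch.zeros(1, 3, 8, 8)
idx = model_predict(img, toy.eval(), device)
print(classes[idx])
```

The printed label depends on the randomly initialized weights, so no fixed output is shown.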


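  • The model expects input of shape (batch size, channel, h, w). A single image is (channel, h, w) after the transform, so unsqueeze(dim=0) adds a batch dimension of size 1.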
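A minimal sketch of that shape change, assuming a 224×224 image and using a zero-filled NumPy array in place of `cv2.imread` output (which is uint8, (H, W, C), BGR order). It mimics what a `ToTensor`-style transform does (HWC to CHW, scale to [0, 1]) and then adds the batch dimension. The BGR to RGB flip is only needed if the model was trained on RGB images.

```python
import numpy as np
import torch

# stand-in for cv2.imread output: uint8 array of shape (H, W, C), channels in BGR order
bgr = np.zeros((224, 224, 3), dtype=np.uint8)

rgb = bgr[:, :, ::-1].copy()  # BGR -> RGB; copy() because torch.from_numpy rejects negative strides
chw = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0  # (C, H, W), values in [0, 1]
batch = chw.unsqueeze(dim=0)  # add batch dimension -> (1, C, H, W)
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```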


  • Wrong

model = model.load_state_dict(torch.load(PATH))  # wrong: overwrites model with the return value

model = model.eval()

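-> Error. load_state_dict() copies the weights into the model in place and does not return the model; it returns a named tuple with missing_keys and unexpected_keys. After the reassignment, model is no longer the network, so any later use of it fails (e.g. model.eval() raises AttributeError, model(img) raises TypeError).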
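A small check of what `load_state_dict()` actually returns, using an `nn.Linear` layer as a stand-in for the ResNet:

```python
import torch.nn as nn

model = nn.Linear(4, 2)
state = model.state_dict()

result = model.load_state_dict(state)  # weights are loaded into `model` in place
print(isinstance(result, nn.Module))   # False
print(result.missing_keys, result.unexpected_keys)  # [] []
print(hasattr(result, "eval"))  # False, so result.eval() would raise AttributeError
```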


  • Correct

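model = ResNet18()  # 1. build the architecture first

model.load_state_dict(torch.load(PATH))  # 2. load the weights in place; do not reassign the return value

model.eval()  # 3. switch to inference mode in place (affects dropout and batch norm)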
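A self-contained round trip of the correct pattern, assuming a hypothetical `TinyNet` in place of the ResNet and a temporary file in place of `PATH`. A freshly built instance that loads the saved `state_dict` produces exactly the same outputs as the original.

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for the ResNet
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

trained = TinyNet()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")  # hypothetical file name
    torch.save(trained.state_dict(), path)  # save the weights only
    state = torch.load(path, map_location="cpu")

model = TinyNet()              # 1. build the architecture
model.load_state_dict(state)   # 2. load the weights in place
model.eval()                   # 3. inference mode

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.equal(trained(x), model(x))
print(same)            # True
print(model.training)  # False
```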