본문 바로가기

AI/PyTorch

[PyTorch] CNN 설계 10.저장된 모델 불러오기 및 test

1. Package load 
2. 데이터셋 다운로드 및 훈련, 검증, 테스트 데이터셋 구성 
3. 하이퍼파라미터 세팅 
4. Dataset 및 DataLoader 할당 
5. 네트워크 설계 
6. train, validation, test 함수 정의 
7. 모델 저장 함수 정의 
8. 모델 생성 및 Loss function, Optimizer 정의 
9. Training 
10. 저장된 모델 불러오기 및 test 
11. Transfer Learning

 

 


10. 저장된 모델 불러오기 및 test 

 

학습한 모델의 성능을 테스트합니다. 저장한 모델 파일을 torch.load를 통해 불러옵니다.

이렇게 불러오면 우리가 얻게 되는 건 아까 저장한 check_point 딕셔너리입니다. 딕셔너리에 저장한 모델의 파라미터는 'net' key에 저장해두었습니다. 이를 불러와 state_dict에 저장합니다. 이렇게 불러온 모델의 파라미터를 모델에 실제로 로드하기 위해서는 nn.Module.load_state_dict를 사용하면 됩니다. 

 

  1. model_path의 경로에 있는 모델 파일을 로드하여, 이를 check_point 변수에 저장합니다. 또한, 미리저장된 모델이 GPU로 학습했는데 CPU 로 불러올 경우, 파이토치에서 모델을 불러오는 함수에 map_location 인자에 device 정보를 전달해야 적용됩니다. (즉, 함수에 map_location=device 가 되어야 합니다.) device 변수는 <1. Package load> 에서 이미 선언했습니다.
  2. check_point 딕셔너리에 접근하여 모델의 파라미터를 state_dict 변수에 저장합니다. 접근을 위한 딕셔너리의 키값은 'net' 입니다.
  3. state_dict의 파라미터들을 새로 선언한 모델(model)에 로드합니다.

 

 

 

 

11. Transfer Learning