1. Package load
2. 데이터셋 다운로드 및 훈련, 검증, 테스트 데이터셋 구성
3. 하이퍼파라미터 세팅
4. Dataset 및 DataLoader 할당
5. 네트워크 설계
6. train, validation, test 함수 정의
7. 모델 저장 함수 정의
8. 모델 생성 및 Loss function, Optimizer 정의
9. Training
10. 저장된 모델 불러오기 및 test
11. Transfer Learning
10. 저장된 모델 불러오기 및 test
학습한 모델의 성능을 테스트합니다. 저장한 모델 파일을 torch.load를 통해 불러옵니다.
이렇게 불러오면 우리가 얻게 되는 건 아까 저장한 check_point 딕셔너리입니다. 딕셔너리에 저장한 모델의 파라미터는 'net' key에 저장해두었습니다. 이를 불러와 state_dict에 저장합니다. 이렇게 불러온 모델의 파라미터를 모델에 실제로 로드하기 위해서는 nn.Module.load_state_dict를 사용하면 됩니다.
- model_path의 경로에 있는 모델 파일을 로드하여, 이를 check_point 변수에 저장합니다. 또한, 미리저장된 모델이 GPU로 학습했는데 CPU 로 불러올 경우, 파이토치에서 모델을 불러오는 함수에 map_location 인자에 device 정보를 전달해야 적용됩니다. (즉, 함수에 map_location=device 가 되어야 합니다.) device 변수는 <1. Package load> 에서 이미 선언했습니다.
- check_point 딕셔너리에 접근하여 모델의 파라미터를 state_dict 변수에 저장합니다. 접근을 위한 딕셔너리의 키값은 'net' 입니다.
- state_dict의 파라미터들을 새로 선언한 모델(model)에 로드합니다.
11. Transfer Learning
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] torch.tensor.sum(), torch.tensor.mean(), torch.tensor.max() (0) | 2021.07.12 |
---|---|
[PyTorch] CNN 설계 11. Transfer Learning (0) | 2021.07.01 |
[PyTorch] torch.save(model.state_dict(), PATH) / torch.load(PATH) (0) | 2021.06.29 |
[PyTorch] torch.tensor와 torch (0) | 2021.06.29 |
[PyTorch] torch.tensor.float(), torch.mean() (0) | 2021.06.29 |