1. Package load
2. 데이터셋 다운로드 및 훈련, 검증, 테스트 데이터셋 구성
3. 하이퍼파라미터 세팅
4. Dataset 및 DataLoader 할당
5. 네트워크 설계
6. train, validation, test 함수 정의
7. 모델 저장 함수 정의
8. 모델 생성 및 Loss function, Optimizer 정의
9. Training
10. 저장된 모델 불러오기 및 test
11. Transfer Learning
7. 모델 저장 함수 정의
모델 저장은 torch.save 함수를 통해 할 수 있습니다. nn.Module.state_dict를 통해 Module, 즉 우리 모델의 파라미터를 가져올 수 있습니다. 이렇게 불러온 파라미터를 check_point 딕셔너리에 저장합니다. 그리고 이 check_point를 정해준 경로에 저장하면 됩니다.
torch.save 는 단순히 모델의 파라미터만 저장하는 함수가 아닙니다. 어떤 파이썬 객체든 저장할 수 있습니다. 그래서 경우에 따라 check_point 딕셔너리에 모델의 파라미터 뿐만 아니라 다른 여러 가지 필요한 정보를 저장할 수도 있습니다. 예를 들어 총 몇 에폭동안 학습한 모델인지 그 정보도 저장할 수 있겠죠?

8. 모델 생성 및 Loss function, Optimizer 정의

9. Training

'AI > PyTorch' 카테고리의 다른 글
[PyTorch] torch.max(outputs, 1) (0) | 2021.06.29 |
---|---|
[PyTorch] Conv1d, Conv2d 차이점 (0) | 2021.06.29 |
[PyTorch] CNN 설계 6. train, validation, test 함수 정의 (0) | 2021.06.15 |
[PyTorch] CNN 설계 5. 네트워크 설계 (0) | 2021.06.15 |
[PyTorch] CNN 설계 4. Dataset 및 DataLoader 할당 (0) | 2021.06.15 |