본문 바로가기

AI/PyTorch

[PyTorch] CNN 설계 7 - 9

1. Package load 
2. 데이터셋 다운로드 및 훈련, 검증, 테스트 데이터셋 구성 
3. 하이퍼파라미터 세팅 
4. Dataset 및 DataLoader 할당 
5. 네트워크 설계 
6. train, validation, test 함수 정의 
7. 모델 저장 함수 정의 
8. 모델 생성 및 Loss function, Optimizer 정의 
9. Training 
10. 저장된 모델 불러오기 및 test 
11. Transfer Learning

 

7. 모델 저장 함수 정의 

 

모델 저장은 torch.save 함수를 통해 할 수 있습니다. nn.Module.state_dict를 통해 Module, 즉 우리 모델의 파라미터를 가져올 수 있습니다. 이렇게 불러온 파라미터를 check_point 딕셔너리에 저장합니다. 그리고 이 check_point를 정해준 경로에 저장하면 됩니다.

torch.save 는 단순히 모델의 파라미터만 저장하는 함수가 아닙니다. 어떤 파이썬 객체든 저장할 수 있습니다. 그래서 경우에 따라 check_point 딕셔너리에 모델의 파라미터 뿐만 아니라 다른 여러 가지 필요한 정보를 저장할 수도 있습니다. 예를 들어 총 몇 에폭동안 학습한 모델인지 그 정보도 저장할 수 있겠죠?

 

 

 

 


8. 모델 생성 및 Loss function, Optimizer 정의 

 

 

 

 


9. Training