1. Package load
2. 데이터셋 다운로드 및 훈련, 검증, 테스트 데이터셋 구성
3. 하이퍼파라미터 세팅
4. Dataset 및 DataLoader 할당
5. 네트워크 설계
6. train, validation, test 함수 정의
7. 모델 저장 함수 정의
8. 모델 생성 및 Loss function, Optimizer 정의
9. Training
10. 저장된 모델 불러오기 및 test
11. Transfer Learning
모델을 만드는 클래스를 정의한다. 이때 nn.Module을 상속한다.
init으로 sequential layer를 만들어준다.
init으로 만든 layer를 forward에서 바로 이어서 진행하게 된다.
5. 네트워크 설계
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
# self.conv 구현
self.conv = nn.Sequential(
## 코드 시작 ##
torch.nn.Conv2d(3, 32, kernel_size=3), # conv_1 해당하는 층
torch.nn.BatchNorm2d(32), # batch_norm_1 해당하는 층
torch.nn.ReLU(), # ReLU_1 해당하는 층
torch.nn.MaxPool2d(kernel_size=2), # maxpool_1 해당하는 층
nn.Conv2d(32, 64, kernel_size=3), # conv_2 해당하는 층
nn.BatchNorm2d(64), # batch_norm_2 해당하는 층
nn.ReLU(), # ReLU_2 해당하는 층
nn.MaxPool2d(kernel_size=2), # maxpool_2 해당하는 층
nn.Conv2d(64, 128, kernel_size=3), # conv_3 해당하는 층
nn.BatchNorm2d(128), # batch_norm_3 해당하는 층
nn.ReLU(), # ReLU_3 해당하는 층
nn.MaxPool2d(kernel_size=2), # maxpool_3 해당하는 층
nn.Conv2d(128, 128, kernel_size=3), # conv_4 해당하는 층
nn.BatchNorm2d(128), # batch_norm_4 해당하는 층
nn.ReLU(), # ReLU_4 해당하는 층
nn.MaxPool2d(kernel_size=2), # maxpool_4 해당하는 층
## 코드 종료 ##
)
# self.fc 구현
## 코드 시작 ##
self.fc1 = nn.Linear(128*5*5, 512)
self.fc2 = nn.Linear(512,2)
## 코드 종료 ##
def forward(self, x):
x = self.conv(x)
x = x.view(x.shape[0], -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
conv 계산을 통해 2d의 결과가 나오게 된다. 이것을 fc에 들어가기 전에 1차원 vector로 쭉 펴준다. 제일 앞의 값은 batch size이다.
DataLoader의 입력 데이터의 shape은 (batch_size, 1, 28, 28)입니다. 하지만 MLP의 입력은 (batch_size, 입력 feature 수)가 되어야 합니다. 따라서 우리는 Loader에서 나온 텐서의 shape을 이와 같은 형태로 변형해주어야 합니다. 이러한 변형을 flatten 이라고 부르기도 합니다. flatten 과정은 forward 함수의 x = x.view(x.size(0), -1) 로 구현했습니다.
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] CNN 설계 7 - 9 (0) | 2021.06.15 |
---|---|
[PyTorch] CNN 설계 6. train, validation, test 함수 정의 (0) | 2021.06.15 |
[PyTorch] CNN 설계 4. Dataset 및 DataLoader 할당 (0) | 2021.06.15 |
[PyTorch] torch.optim의 인자 model.parameters() (0) | 2021.06.15 |
[PyTorch] torch.nn 제공함수 (0) | 2021.06.15 |