본문 바로가기

AI/PyTorch

[PyTorch] torch.save(model.state_dict(), PATH) / torch.load(PATH)

 

 

모델 저장하기 & 불러오기 — PyTorch Tutorials 1.9.0+cu102 documentation

Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich 번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다

tutorials.pytorch.kr

목차:

 

 

모델을 저장하거나 불러올 때는 3가지의 핵심 함수와 익숙해질 필요가 있습니다:

  1. torch.save: 직렬화된 객체를 디스크에 저장합니다. 이 함수는 Python의 pickle 을 사용하여 직렬화합니다 이 함수를 사용하여 모든 종류의 객체의 모델, Tensor 및 사전을 저장할 수 있습니다.
  2. torch.load: pickle을 사용하여 저장된 객체 파일들을 역직렬화하여 메모리에 올립니다. 이 함수는 데이터를 장치에 불러올 때도 사용합니다. (장치간 모델 저장하기 & 불러오기 참고)
  3. torch.nn.Module.load_state_dict: 역직렬화된 state_dict 를 사용하여 모델의 매개변수들을 불러옵니다. state_dict 에 대한 더 자세한 정보는 state_dict가 무엇인가요? 를 참고하세요.

 

1. state_dict 가 무엇인가요?

PyTorch에서 torch.nn.Module 모델의 학습 가능한 매개변수(예. 가중치와 편향)들은 모델의 매개변수에 포함되어 있습니다(model.parameters()로 접근합니다). state_dict 는 간단히 말해 각 계층을 매개변수 텐서로 매핑되는 Python 사전(dict) 객체입니다. 이 때, 학습 가능한 매개변수를 갖는 계층(합성곱 계층, 선형 계층 등) 및 등록된 버퍼들(batchnorm의 running_mean)만이 모델의 state_dict 에 항목을 가짐을 유의하시기 바랍니다. 옵티마이저 객체(torch.optim) 또한 옵티마이저의 상태 뿐만 아니라 사용된 하이퍼 매개변수(Hyperparameter) 정보가 포함된 state_dict 를 갖습니다.

state_dict 객체는 Python 사전이기 때문에 쉽게 저장하거나 갱신하거나 바꾸거나 되살릴 수 있으며, PyTorch 모델과 옵티마이저에 엄청난 모듈성(modularity)을 제공합니다.

 

 

# 모델 정의
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 모델 초기화
model = TheModelClass()

# 옵티마이저 초기화
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 모델의 state_dict 출력
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# 옵티마이저의 state_dict 출력
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

출력:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

 

2. 추론(inference)를 위해 모델 저장하기 & 불러오기

state_dict 저장하기 / 불러오기 (권장)

저장하기:

torch.save(model.state_dict(), PATH)

불러오기:

model = TheModelClass(*args, **kwargs) 
model.load_state_dict(torch.load(PATH)) 
model.eval()

 

load_state_dict()  함수에는 저장된 객체의 경로가 아닌, 사전 객체를 전달해야 하는 것에 유의하세요. 따라서 저장된 state_dict 를 load_state_dict() 함수에 전달하기 전에 반드시 역직렬화를 해야 합니다.
예를 들어, model.load_state_dict(PATH) 과 같은 식으로는 사용하면 안됩니다.

 

전체 모델 저장하기/불러오기

저장하기:

torch.save(model, PATH)

불러오기:

# 모델 클래스는 어딘가에 반드시 선언되어 있어야 합니다
model = torch.load(PATH)
model.eval()

 

이 저장하기/불러오기 과정은 가장 직관적인 문법을 사용하며 적은 양의 코드를 사용합니다. 이러한 방식으로 모델을 저장하는 것은 Python의 pickle 모듈을 사용하여 전체 모듈을 저장하게 됩니다. 하지만 pickle은 모델 그 자체를 저장하지 않기 때문에 직렬화된 데이터가 모델을 저장할 때 사용한 특정 클래스 및 디렉토리 경로(구조)에 얽매인다는 것이 이 방식의 단점입니다. 대신에 클래스가 위치한 파일의 경로를 저장해두고, 불러오는 시점에 사용합니다. 이러한 이유 때문에, 만들어둔 코드를 다른 프로젝트에서 사용하거나 리팩토링 후에 다양한 이유로 동작하지 않을 수 있습니다.

 

PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.

 

 

3. 추론 / 학습 재개를 위해 일반 체크포인트(checkpoint) 저장하기 & 불러오기

추론 또는 학습 재개를 위해 일반 체크포인트를 저장할 때는 반드시 모델의 state_dict 보다 많은 것들을 저장해야 합니다. 모델이 학습을 하며 갱신되는 버퍼와 매개변수가 포함된 옵티마이저의 state_dict 도 함께 저장하는 것이 중요합니다. 그 외에도 마지막 에폭(epoch), 최근에 기록된 학습 손실, 외부 torch.nn.Embedding 계층 등도 함께 저장합니다. 결과적으로, 이런 체크포인트는 종종 모델만 저장하는 것보다 2~3배 정도 커지게 됩니다.

저장하기:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

불러오기:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

 

추론 또는 학습 재개를 위해 일반 체크포인트를 저장할 때는 반드시 모델의 state_dict 보다 많은 것들을 저장해야 합니다. 모델이 학습을 하며 갱신되는 버퍼와 매개변수가 포함된 옵티마이저의 state_dict 도 함께 저장하는 것이 중요합니다. 그 외에도 마지막 에폭(epoch), 최근에 기록된 학습 손실, 외부 torch.nn.Embedding 계층 등도 함께 저장합니다. 결과적으로, 이런 체크포인트는 종종 모델만 저장하는 것보다 2~3배 정도 커지게 됩니다.

여러가지를 함께 저장하려면, 사전(dictionary) 자료형으로 만든 후 torch.save() 를 사용하여 직렬화합니다. PyTorch가 이러한 체크포인트를 저장할 때는 .tar 확장자를 사용하는 것이 일반적인 규칙입니다.

항목들을 불러올 때에는 먼저 모델과 옵티마이저를 초기화한 후, torch.load() 를 사용하여 사전을 불러옵니다. 이후로는 저장된 항목들을 사전에 원하는대로 사전에 질의하여 쉽게 접근할 수 있습니다.

추론을 실행하기 전에는 반드시 model.eval() 을 호출하여 드롭아웃 및 배치 정규화를 평가 모드로 설정하여야 합니다. 이것을 하지 않으면 추론 결과가 일관성 없게 출력됩니다. 만약 학습을 계속하고 싶다면, model.train() 을 호출하여 학습 모드로 전환되도록 해야 합니다.

 

def save_model(model, saved_dir, file_name='best_model.pt'):
    os.makedirs(saved_dir, exist_ok=True)
    check_point = {
        'net': model.state_dict()
    }
    output_path = os.path.join(saved_dir, file_name)
    ## 코드 시작 ##
    torch.save(check_point, output_path)
    ## 코드 종료 ##