본문 바로가기

AI/PyTorch

[PyTorch] model의 parameter 접근하기

nn.Linear()등으로 정의한 파라미터 접근은 parameter(), named_parameters()으로 가능하다. 정확히는 layer가 모두 nn.Module()을 상속받으므로 Module에 정의되어 있는 parameter 접근 방법을 사용하면 된다.

1. torch.nn.Module.parameters()

parameter()는 layer 이름을 제외한 parameter값에 대한 iterator를 준다.

layer = torch.nn.Linear(10,3)
layer
>> Linear(in_features=10, out_features=3, bias=True)

for p in layer.parameters():
	print(p)
    
>> Parameter containing:
   tensor([[-0.2232,  0.0369, -0.2201, -0.1385, -0.1104,  0.2852,  0.0249, -0.0295,
             0.0382,  0.2847],
           [-0.0430,  0.3032,  0.1541, -0.3093,  0.1008, -0.3134, -0.1431,  0.0280,
             0.2178,  0.2094],
           [ 0.1900, -0.2386,  0.1099,  0.1769, -0.0338, -0.2079,  0.0816, -0.0180,
            -0.0182,  0.1578]], requires_grad=True)
   Parameter containing:
   tensor([ 0.1549, -0.0628,  0.1692], requires_grad=True)

 

2. torch.nn.Module.named_parameters()

named_parameters()는 (name, parameter) 조합의 tuple iterator를 준다. 이때 출력물은 named_tuple이 아니니 헷갈리지 말자.

for p in layer.named_parameters():
	print(p)
    
>> ('weight', Parameter containing:
    tensor([[-0.2232,  0.0369, -0.2201, -0.1385, -0.1104,  0.2852,  0.0249, -0.0295,
              0.0382,  0.2847],
            [-0.0430,  0.3032,  0.1541, -0.3093,  0.1008, -0.3134, -0.1431,  0.0280,
              0.2178,  0.2094],
            [ 0.1900, -0.2386,  0.1099,  0.1769, -0.0338, -0.2079,  0.0816, -0.0180,
             -0.0182,  0.1578]], requires_grad=True))
    ('bias', Parameter containing:
    tensor([ 0.1549, -0.0628,  0.1692], requires_grad=True))
    
# 이름만 출력    
for name, p in layer.named_parameters():
	print(name)
    
>> weight
   bias

 

 

모델 파라미터(parameter) 확인하기 - children()

for child in model.children():
        count += 1
        if count == 2:
            print(child)

children()에서 name변수를 이용하면 현재 parameter(layer) 위치 확인이 가능하다.

count = 0
for name, layer in model.named_childeren():
	count += 1
    if count == 1:
    	print(name) # conv1출력

torchsummary 사용하기

또는 torchsummary를 사용하여 output shape와 parameter수를 같이 확인하는 방법도 있다.

!pip install torchsummary
from torchsummary import summary as summary

model = models.resnet50(pretrained=True)
summary(model, (3,224,224)) # (model, input_size)

 

 

 

 

[Pytorch] 모델 구조 확인, parameter확인

모델 구조(architecture) 확인하기 model = models.resnet50(pretrained = True) print(model) 모델 파라미터(parameter) 확인하기 - parameters() for name, param in model.named_parameters(): count+=1 if cou..

rabo0313.tistory.com

 

'AI > PyTorch' 카테고리의 다른 글

[PyTorch] torch.nn.CrossEntropyLoss()  (0) 2022.01.19
[PyTorch] torchvision.dataset.CoCoDetection  (0) 2022.01.17
[PyTorch] 버전 변경하기  (0) 2021.11.23
[PyTorch] 모델 앙상블(ensemble) 하기  (0) 2021.10.29
[PyTorch] numpy.mean()  (0) 2021.07.12