nn.Linear()등으로 정의한 파라미터 접근은 parameter(), named_parameters()으로 가능하다. 정확히는 layer가 모두 nn.Module()을 상속받으므로 Module에 정의되어 있는 parameter 접근 방법을 사용하면 된다.
1. torch.nn.Module.parameters()
parameter()는 layer 이름을 제외한 parameter값에 대한 iterator를 준다.
layer = torch.nn.Linear(10,3)
layer
>> Linear(in_features=10, out_features=3, bias=True)
for p in layer.parameters():
print(p)
>> Parameter containing:
tensor([[-0.2232, 0.0369, -0.2201, -0.1385, -0.1104, 0.2852, 0.0249, -0.0295,
0.0382, 0.2847],
[-0.0430, 0.3032, 0.1541, -0.3093, 0.1008, -0.3134, -0.1431, 0.0280,
0.2178, 0.2094],
[ 0.1900, -0.2386, 0.1099, 0.1769, -0.0338, -0.2079, 0.0816, -0.0180,
-0.0182, 0.1578]], requires_grad=True)
Parameter containing:
tensor([ 0.1549, -0.0628, 0.1692], requires_grad=True)
2. torch.nn.Module.named_parameters()
named_parameters()는 (name, parameter) 조합의 tuple iterator를 준다. 이때 출력물은 named_tuple이 아니니 헷갈리지 말자.
for p in layer.named_parameters():
print(p)
>> ('weight', Parameter containing:
tensor([[-0.2232, 0.0369, -0.2201, -0.1385, -0.1104, 0.2852, 0.0249, -0.0295,
0.0382, 0.2847],
[-0.0430, 0.3032, 0.1541, -0.3093, 0.1008, -0.3134, -0.1431, 0.0280,
0.2178, 0.2094],
[ 0.1900, -0.2386, 0.1099, 0.1769, -0.0338, -0.2079, 0.0816, -0.0180,
-0.0182, 0.1578]], requires_grad=True))
('bias', Parameter containing:
tensor([ 0.1549, -0.0628, 0.1692], requires_grad=True))
# 이름만 출력
for name, p in layer.named_parameters():
print(name)
>> weight
bias
모델 파라미터(parameter) 확인하기 - children()
for child in model.children():
count += 1
if count == 2:
print(child)
children()에서 name변수를 이용하면 현재 parameter(layer) 위치 확인이 가능하다.
count = 0
for name, layer in model.named_childeren():
count += 1
if count == 1:
print(name) # conv1출력
torchsummary 사용하기
또는 torchsummary를 사용하여 output shape와 parameter수를 같이 확인하는 방법도 있다.
!pip install torchsummary
from torchsummary import summary as summary
model = models.resnet50(pretrained=True)
summary(model, (3,224,224)) # (model, input_size)
[Pytorch] 모델 구조 확인, parameter확인
모델 구조(architecture) 확인하기 model = models.resnet50(pretrained = True) print(model) 모델 파라미터(parameter) 확인하기 - parameters() for name, param in model.named_parameters(): count+=1 if cou..
rabo0313.tistory.com
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] torch.nn.CrossEntropyLoss() (0) | 2022.01.19 |
---|---|
[PyTorch] torchvision.dataset.CoCoDetection (0) | 2022.01.17 |
[PyTorch] 버전 변경하기 (0) | 2021.11.23 |
[PyTorch] 모델 앙상블(ensemble) 하기 (0) | 2021.10.29 |
[PyTorch] numpy.mean() (0) | 2021.07.12 |