cls_loss = nn.CrossEntropyLoss()
test_pred_y = torch.Tensor([[2,0.1,0.3],[0,1,0.3]]) # 실제 사용에선 softmax에 의해 각 행의 합이 1이 될 것이다.
test_true_y1 = torch.Tensor([1,0]).long() # 1은 true값이 1번째(클래스)라는 것을 의미
test_true_y2 = torch.Tensor([0,1]).long()
print(test_pred_y.shape)
print(test_pred_y)
print(test_true_y1)
print(test_true_y2)
print(cls_loss(test_pred_y, test_true_y1))
print(cls_loss(test_pred_y, test_true_y2))
input이 (2,3) shape이라면 각 행마다의 정답 target을 지정해주기 위해 (1,2) shape으로 만들어 준다.
import torch
import torch.nn as nn
import numpy as np
output = torch.Tensor(
[
[0.8982, 0.805, 0.6393, 0.9983, 0.5731, 0.0469, 0.556, 0.1476, 0.8404, 0.5544],
[0.9457, 0.0195, 0.9846, 0.3231, 0.1605, 0.3143, 0.9508, 0.2762, 0.7276, 0.4332]
]
)
target = torch.LongTensor([1, 5])
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss) # tensor(2.3519)
다중분류를 위한 대표적인 손실함수, torch.nn.CrossEntropyLoss – GIS Developer
딥러닝의 많은 이론 중 가장 중요한 부분이 손실함수와 역전파입니다. PyTorch에서는 다양한 손실함수를 제공하는데, 그 중 torch.nn.CrossEntropyLoss는 다중 분류에 사용됩니다. torch.nn.CrossEntropyLoss는 nn
www.gisdeveloper.co.kr
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] basic (0) | 2022.11.11 |
---|---|
[Pytorch] 텐서 쌓기 함수 torch.cat(), torch.stack() 비교 (0) | 2022.04.19 |
[PyTorch] torchvision.dataset.CoCoDetection (0) | 2022.01.17 |
[PyTorch] model의 parameter 접근하기 (0) | 2022.01.06 |
[PyTorch] 버전 변경하기 (0) | 2021.11.23 |