본문 바로가기

AI/PyTorch

[PyTorch] torch.nn.CrossEntropyLoss()

cls_loss = nn.CrossEntropyLoss()


test_pred_y = torch.Tensor([[2,0.1,0.3],[0,1,0.3]]) # 실제 사용에선 softmax에 의해 각 행의 합이 1이 될 것이다.
test_true_y1 = torch.Tensor([1,0]).long() # 1은 true값이 1번째(클래스)라는 것을 의미
test_true_y2 = torch.Tensor([0,1]).long()

print(test_pred_y.shape)

print(test_pred_y)
print(test_true_y1)
print(test_true_y2)

print(cls_loss(test_pred_y, test_true_y1))
print(cls_loss(test_pred_y, test_true_y2))

input이 (2,3) shape이라면 각 행마다의 정답 target을 지정해주기 위해 (1,2) shape으로 만들어 준다.

 

import torch
import torch.nn as nn
import numpy as np
output = torch.Tensor(
    [
        [0.8982, 0.805, 0.6393, 0.9983, 0.5731, 0.0469, 0.556, 0.1476, 0.8404, 0.5544],
        [0.9457, 0.0195, 0.9846, 0.3231, 0.1605, 0.3143, 0.9508, 0.2762, 0.7276, 0.4332]
    ]
)
target = torch.LongTensor([1, 5])
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss) # tensor(2.3519)

 

 

 

 

다중분류를 위한 대표적인 손실함수, torch.nn.CrossEntropyLoss – GIS Developer

딥러닝의 많은 이론 중 가장 중요한 부분이 손실함수와 역전파입니다. PyTorch에서는 다양한 손실함수를 제공하는데, 그 중 torch.nn.CrossEntropyLoss는 다중 분류에 사용됩니다. torch.nn.CrossEntropyLoss는 nn

www.gisdeveloper.co.kr