_, argmax = torch.max(outputs, 1)
Parameters
- input ( Tensor ) – 입력 텐서.
- dim (int) – the dimension to reduce.
- keepdim ( bool ) – 출력 텐서가 dim 유지 되었는지 여부 . 기본값 : False .
리턴 namedtuple (values, indices) values 의 각 행의 최대 값 input 주어진 차원에서 텐서 dim . 그리고 indices 는 발견 된 각 최대 값 (argmax)의 인덱스 위치입니다.
첫번째는 최대값을 tensor로 두번째는 index값을 tensor로 리턴한다.
torch.max / dim 역할
torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
input의 형태가 [A,B,C,D]라고 할 때
dim=n 이라고 하면 n번째를 제외한 output이 나오게 되고
ex) dim = 2, C가 빠져 [A,B,D]가 나오게 된다.
이 C에 해당하는 데이터를 기준으로 최대값 및 인덱스가 튜플로 나오게 된다.
참고로 예시에서 [3,-1,5] 이 부분은 -1에 해당하는 부분은 차원이 맞게 끔 알아서 조절된다.
input이 [3,5,2,2] 였기 때문에 -1에은 4가 들어가게 된다.
import torch
aa = torch.randint(0,5,[3,5,2,2])
# print(aa)
bb = torch.reshape(aa,[3,-1,5])
print("INPUT : ")
print(bb)
c = torch.max(bb,dim=2)
print("dim=2")
print(c)
d = torch.max(bb,dim=1)
print("dim=1")
print(d)
d = torch.max(bb,dim=0)
print("dim=0")
print(d)
d = torch.max(bb,dim=-1)
print("dim= -1")
print(d)
Output :
INPUT :
tensor([[[3, 4, 2, 0, 3],
[0, 4, 0, 3, 0],
[2, 0, 4, 3, 3],
[1, 3, 3, 0, 4]],
[[0, 3, 4, 2, 0],
[0, 3, 4, 0, 2],
[4, 0, 3, 2, 2],
[3, 1, 1, 0, 1]],
[[3, 4, 3, 0, 2],
[2, 4, 1, 3, 1],
[4, 0, 1, 1, 4],
[4, 4, 1, 1, 0]]])
dim=2
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
[4, 4, 4, 3],
[4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
[2, 2, 0, 0],
[1, 1, 0, 0]]))
dim=1
torch.return_types.max(
values=tensor([[3, 4, 4, 3, 4],
[4, 3, 4, 2, 2],
[4, 4, 3, 3, 4]]),
indices=tensor([[0, 0, 2, 1, 3],
[2, 0, 0, 0, 1],
[2, 0, 0, 1, 2]]))
dim=0
torch.return_types.max(
values=tensor([[3, 4, 4, 2, 3],
[2, 4, 4, 3, 2],
[4, 0, 4, 3, 4],
[4, 4, 3, 1, 4]]),
indices=tensor([[0, 0, 1, 1, 0],
[2, 0, 1, 0, 1],
[1, 0, 0, 0, 2],
[2, 2, 0, 2, 0]]))
dim= -1
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
[4, 4, 4, 3],
[4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
[2, 2, 0, 0],
[1, 1, 0, 0]]))
torch.max / dim 역할
input의 형태가 A,B,C,D라고 할 때dim=n 이라고 하면 n번째를 제외한 output이 나오게 되고 ex) dim = 2, C가 빠져 A,B,D가 나오게 된다.이 C에 해당하는 데이터를 기준으로 최대값 및 인덱스가 튜플로 나오게
velog.io
내가 작성한 코드↓
import torch
input = torch.Tensor([[2,5,4,6],[3,2,1,8],[7,8,9,5]])
lab = torch.Tensor([3, 3, 1])
a,b = torch.max(input,1)
print(a)
print(b)
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] torch.tensor와 torch (0) | 2021.06.29 |
---|---|
[PyTorch] torch.tensor.float(), torch.mean() (0) | 2021.06.29 |
[PyTorch] Conv1d, Conv2d 차이점 (0) | 2021.06.29 |
[PyTorch] CNN 설계 7 - 9 (0) | 2021.06.15 |
[PyTorch] CNN 설계 6. train, validation, test 함수 정의 (0) | 2021.06.15 |