본문 바로가기

AI/PyTorch

[PyTorch] torch.max(outputs, 1)

 _, argmax = torch.max(outputs, 1)

 

Parameters

  • input ( Tensor ) – 입력 텐서.
  • dim (int) – the dimension to reduce.
  • keepdim ( bool ) – 출력 텐서가 dim 유지 되었는지 여부 . 기본값 : False .

리턴 namedtuple (values, indices) values 의 각 행의 최대 값 input 주어진 차원에서 텐서 dim . 그리고 indices 는 발견 된 각 최대 값 (argmax)의 인덱스 위치입니다.

 

첫번째는 최대값을 tensor로 두번째는 index값을 tensor로 리턴한다.

 

torch.max / dim 역할

torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)

input의 형태가 [A,B,C,D]라고 할 때
dim=n 이라고 하면 n번째를 제외한 output이 나오게 되고
ex) dim = 2, C가 빠져 [A,B,D]가 나오게 된다.
이 C에 해당하는 데이터를 기준으로 최대값 및 인덱스가 튜플로 나오게 된다.

참고로 예시에서 [3,-1,5] 이 부분은 -1에 해당하는 부분은 차원이 맞게 끔 알아서 조절된다.
input이 [3,5,2,2] 였기 때문에 -1에은 4가 들어가게 된다.

 

import torch
aa = torch.randint(0,5,[3,5,2,2])

# print(aa)

bb = torch.reshape(aa,[3,-1,5])
print("INPUT : ")
print(bb)
c = torch.max(bb,dim=2)
print("dim=2")
print(c)
d = torch.max(bb,dim=1)
print("dim=1")
print(d)
d = torch.max(bb,dim=0)
print("dim=0")
print(d)
d = torch.max(bb,dim=-1)
print("dim= -1")
print(d)

 

Output :

INPUT : 
tensor([[[3, 4, 2, 0, 3],
         [0, 4, 0, 3, 0],
         [2, 0, 4, 3, 3],
         [1, 3, 3, 0, 4]],

        [[0, 3, 4, 2, 0],
         [0, 3, 4, 0, 2],
         [4, 0, 3, 2, 2],
         [3, 1, 1, 0, 1]],

        [[3, 4, 3, 0, 2],
         [2, 4, 1, 3, 1],
         [4, 0, 1, 1, 4],
         [4, 4, 1, 1, 0]]])
dim=2
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
        [4, 4, 4, 3],
        [4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
        [2, 2, 0, 0],
        [1, 1, 0, 0]]))
dim=1
torch.return_types.max(
values=tensor([[3, 4, 4, 3, 4],
        [4, 3, 4, 2, 2],
        [4, 4, 3, 3, 4]]),
indices=tensor([[0, 0, 2, 1, 3],
        [2, 0, 0, 0, 1],
        [2, 0, 0, 1, 2]]))
dim=0
torch.return_types.max(
values=tensor([[3, 4, 4, 2, 3],
        [2, 4, 4, 3, 2],
        [4, 0, 4, 3, 4],
        [4, 4, 3, 1, 4]]),
indices=tensor([[0, 0, 1, 1, 0],
        [2, 0, 1, 0, 1],
        [1, 0, 0, 0, 2],
        [2, 2, 0, 2, 0]]))
dim= -1
torch.return_types.max(
values=tensor([[4, 4, 4, 4],
        [4, 4, 4, 3],
        [4, 4, 4, 4]]),
indices=tensor([[1, 1, 2, 4],
        [2, 2, 0, 0],
        [1, 1, 0, 0]]))

 

 

 

 

 

 

torch.max / dim 역할

input의 형태가 A,B,C,D라고 할 때dim=n 이라고 하면 n번째를 제외한 output이 나오게 되고 ex) dim = 2, C가 빠져 A,B,D가 나오게 된다.이 C에 해당하는 데이터를 기준으로 최대값 및 인덱스가 튜플로 나오게

velog.io

 

 

내가 작성한 코드↓

 

import torch

input = torch.Tensor([[2,5,4,6],[3,2,1,8],[7,8,9,5]])

lab = torch.Tensor([3, 3, 1])
a,b = torch.max(input,1)
print(a)
print(b)

 

output