torch.where()로 tensor내 특정 원소의 위치를 찾기

gaussian heatmap에서 landmark localization (x,y) 좌표를 얻어내는 방법은

 

landmark localization이라는 것이 가장 주목할 부분, heatmap에서 가장 밝게 빛나는 부분이므로

 

heatmap의 activation value중 가장 큰 값의 (x,y)좌표를 얻어오면 된다

 

주어진 heatmap tensor hm에서 최댓값 부분을 어떻게 찾아오느냐?

 

hm에서 최댓값을 가져오려면 hm.max()나 torch.max(hm)을 사용한다

 

activation heatmap의 표현과 최대 activation value

 

 

hm==torch.max(hm)을 하면 True, False를 원소로 가지는 hm 크기와 동일한 tensor가 나온다

 

boolean tensor 표현

 

 

True의 위치를 찾는게 목적이라고 할 수 있다. 어떻게 찾을까?

 

torch.where()나 np.where()함수는 condition을 받아 True의 위치를 반환해준다

 

 

 

np.where()는 numpy array를 쓴다는 점만 차이 있다

 

x,y를 쓰지 않는다면 torch.where(condition)은 condition에 맞는 index만 tuple로 반환함

 

 

 

dim=0이 y좌표, dim=1이 x좌표

 

max의 위치는 dim=0에서 y=6 dim=1에서 x=44로 (44,6)에 있다는 것을 알 수 있다

 

 

 

참고로 torch.nonzero(condition,as_tuple=True)를 사용해도 동일하다

 

False가 0이고 True가 1이라 0이 아닌 곳을 찾아달라해서

 

 

TAGS.

Comments