pytorch의 tensor를 plt.imshow()했더니 TypeError: Invalid shape for image data

https://velog.io/@olxtar/Torchvision-PIL-torch.Tensor-PIL-Image

 

[Torchvision / PIL] torch.Tensor <-> PIL Image

PIL/Numpy Array/Torch Tensor 이미지끼리 변환하기 / torchvision.transforms.ToTensor() / torchvision.transforms.ToPILImage

velog.io

 

PIL이나 opencv로 이미지를 열때는 (height, width, channel) 순으로 shape를 가지게 된다

 

img = Image.open('/content/karina.jpeg')
img2 = cv2.imread('/content/karina.jpeg')

print(np.array(img).shape)
print(img2.shape)

(1132, 900, 3)
(1132, 900, 3)

 

 

이렇게 (height, width, channel) 순으로 저장된 경우에는 matplotlib의 시각화가 적용되는데

 

 

 

문제는 pytorch에서 딥러닝할때, ToTensor()로 image를 tensor로 바꿔서 dataset을 만드는데,

 

(channel, height, width)순으로 저장한다는 사실

 

f = torchvision.transforms.ToTensor()

img_tensor = f(img)

print(img_tensor.shape)

torch.Size([3, 1132, 900])

 

 

근데, 이 상태에서 plt.imshow()를 적용하면 에러가 난다

 

 

 

그래서 이 경우에는 permute 함수로 shape를 (height, width, channel)로 바꿔주고 plt.imshow()를 적용해야한다

 

permute(1,2,0)은 shape의 0번 index가 channel인데 2번 index로 옮겨주고

 

1번 index는 height인데 0번 index로 옮겨주고

 

2번 index는 width인데 1번 index로 옮긴다는 의미

 

#shape (0,1,2) >>> (1,2,0)
#shape (channel,height,width) >>> (height,width,channel)
plt.imshow(img_tensor.permute(1,2,0))

 

 

 

추가적으로 PIL이 image를 0~255 pixel로 저장하는데 ToTensor()는 pixel을 0~1로 저장한다는 것도 눈여겨볼만하다

 

 

 

 

https://stackoverflow.com/questions/53623472/how-do-i-display-a-single-image-in-pytorch

 

How do I display a single image in PyTorch?

How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error: TypeError: Invalid dimensions for image data

stackoverflow.com

 

TAGS.

Comments