Pytorch model forward에서 에러나는 경우 대처하기(input output model shape print해보기)
pytorch에서 model에 input을 넣어 forward 과정을 거쳐 output을 낼려고 할 때 종종 에러가 나는데
input으로 3d, 4d, 5d를 받아야하는데 2d가 들어왔다고 말하는거
딥러닝은 모델이 너무 복잡하기 때문에 머릿속에서 생각만으로 어디가 문제인지 알아내기 어렵다
모델 내부에서나, input, output등에 대해 중간중간에 shape를 찍어봐야함
중간에 grad_CAM의 shape를 찍어보면 실제로 2d라는 걸 확인할 수 있음
2d를 2번 unsqueeze(0)해서 4d로 만들고 넣었더니 더 이상 에러가 없었다
왜 2번했냐고? 1번만 하면 또 에러나서 그래
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
비슷하게 모델에 img를 그냥 넣으면 shape때문에
4차원을 넣어야하는데 3차원을 넣었다며 에러난다.
img.unsqueeze(0)으로 차원을 높여서 넣어주면 된다
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
preds=model(imgs)[0]과 preds=model(imgs)를 했을 때 차이를 생각해보자
model(imgs)[0]으로 하면 hm의 dimension은 80*80으로 2차원을 가진다
이번엔 preds=model(imgs)만 해본다면
22*80*80으로 3차원 tensor가 된다. 왜 이런일이 일어날까?
model의 forward pass를 살펴보는 것이 은근 중요할때가 있다
forward pass에 재밌는 부분이 out=[]로 리스트를 반환한다는 것
preds를 쳐보면 리스트 안에 하나의 output tensor가 들어있는 형태로 나온다
그 안의 tensor의 차원이 8*22*80*80이라는 것을 기억해보자
이게 왜 잘 구분해야하는지 생각해보자.
우리의 목표는 activation heat map의 최댓값의 x,y 좌표인데 그럴려면 activation map이 2차원형태로 얻어져야한다.
만약에 preds=model(imgs)라고 쓴다면
for img,pred_hm in zip(imgs,preds):에서 preds 리스트의 원소 하나를 call하여 pred_hm으로 가져오니
pred_hm은 차원이 8*22*80*80인 tensor이고 사실 이것은 22*80*80이 8개 있는 것이다
그래서 다음 줄 for hm in pred_hm:에서 hm이 22*80*80이 된다
반면 preds = model(imgs)[0]이라고 쓴다면
먼저 8*22*80*80 tensor로 시작하고 사실 이것은 22*80*80이 8개 있는 것과 동일하다
그래서 다음 줄 for img,pred_hm in zip(imgs,preds):에서 pred_hm이 그 중 1개인 22*80*80 tensor이 되는 것이고
사실 이것은 80*80이 22개 있는 것과 동일하다.
그래서 다음 줄 for hm in pred_hm:에서 hm이 그 중 1개인 80*80으로 나올 수 있다
근데 사실 이것을 논쟁할 이유가 없는 것이 명확히 맞는 답은 preds=model(imgs)[0]이 맞다.
왜냐하면 preds=model(imgs)은 실제 output이 1개 들어간 리스트잖아.
그런데 원하는 것은 리스트 안에 들어간 output이니까
아무튼 결론은 생각을 잘하자. 모델 구조를 잘 살펴보면서 아무 생각없이 코딩하지말고
'프로그래밍 > Pytorch' 카테고리의 다른 글
Unexpected key(s) in state_dict: "model.electra.embeddings.position_ids". 에러 해결하기 (0) | 2024.10.30 |
---|---|
Pytorch에서 padding sequence vs. packed sequence 차이 이해하고 구현하기 (0) | 2024.04.19 |
Pytorch에서 learning rate scheduler 사용하는 방법 알기 (0) | 2024.04.17 |
torch.where()로 tensor내 특정 원소의 위치를 찾기 (0) | 2024.04.14 |
Pytorch의 computational graph와 backward()에 대해 이해하기 (0) | 2024.04.13 |