Pytorch의 computational graph와 backward()에 대해 이해하기
1. computational graph
computational graph라는 것은 pytorch가 최종 변수에 대한 (위 그림에서는 L)
forward pass를 통해 계산되는 모든 과정이 graph 형태로 저장되어 있는 것을 의미한다.
위의 그림은 a, b, c, d, L, w1,w2,w3,w4 9개의 변수 값들의 계산 과정이 기록되어 있는 computational graph이다.
위 그림에서 예를 들어 c를 계산할려면 a와 w2의 어떤 연산으로 c가 계산되어진다는 의미다.
이렇게 저장을 해놓으면 chain rule에 의한 backward pass 계산이 쉬워진다.
2. backward
forward pass를 통해 변수에 대해 계산을 하면 pytorch에서 알아서 computational graph를 계산해놓는다.
최종 변수 L에 대하여 L.backward()를 하면 L에 대한 모든 미분 $\frac{dL}{d?}$을 알아서 계산해놓는다
원하는 특정 미분값을 불러오고 싶으면?
예를 들어 $\frac{dL}{da}$는 a.grad로 불러올 수 있다
pytorch에서 backward는 auto gradient의 핵심 함수라고 할 수 있는데
가장 기본은 single element를 가진 tensor에 backward를 걸어야 미분이 계산된다
tensor가 아닌 그냥 scalar를 쓰면 backward를 할 수 없다는 에러가 남
y가 multiple tensor인데 backward()를 걸면
‘grad can be implicitly created only for scalar outputs’라고 오직 스칼라 출력에 대해서만 grad를 계산할 수 있다고 나옴
단일 스칼라 element를 가지는 tensor에 대해서만 backward를 할 수 있다는 이야기.
참고로 torch.randn(<원소 수>)는 원소 수만큼 random으로 N(0,1)에서 sample을 뽑아서 tensor를 생성
----------------------------------------------------------------------------------------------------------------------------------------------------
함수에 걸어야되는거 아니냐 이런 생각을 했었는데 스칼라 tensor 에 backward를 걸면
알아서 그 스칼라가 구해지는 함수를 추적해서 미분을 알아서 해줌
out은 7개 원소를 가지는 tensor고 max_score는 스칼라임
이렇게 해서 7개 원소를 가지는 tensor out에 backward를 걸어보면
1차원이 아닌 tensor에 backward를 걸면 이런 오류가 뜬다
----------------------------------------------------------------------------------------------------------------------------------------------------
추가적으로 float single tensor인 경우에만 backward를 걸 수 있다.
float tensor가 아니면 requires_grad 조차도 할 수 없다
loss가 criterion = nn.CrossEntropyloss()로 해가지고 함수로 인식되어가지고..
loss.backward()만 하다보니까
backward가 함수에 걸어야되는거 아니냐 이런 생각을 했었는데
스칼라 tensor에 backward를 걸면 알아서 그 스칼라 tensor가 구해지는 함수를 추적해서 미분을 알아서 해줌
requires_grad=True 조건은 해당 tensor에 대한 모든 연산을 추적하겠다는 의미로
해당 tensor에 대해서는 gradient를 계산할 수 있다는 뜻이며 조건이 없으면(False이면) gradient를 계산하지 못한다면서 에러가 남
참고로 require_grad=True이면 tensor를 출력할 때 grad_fn = <> 같은 정보가 나타남
모든 조건에 맞게 backward를 통해 $\frac{dw}{dx}$를 계산하면
computational graph는 이런 식으로 그려져 있을 것이다
'프로그래밍 > Pytorch' 카테고리의 다른 글
Pytorch에서 learning rate scheduler 사용하는 방법 알기 (0) | 2024.04.17 |
---|---|
torch.where()로 tensor내 특정 원소의 위치를 찾기 (0) | 2024.04.14 |
Pytorch에서 두 tensor가 서로 같은지 비교하고 싶다면? (0) | 2024.04.11 |
pretrained된 computer vision 모델에서 마지막 linear layer는 제거하고 feature만 뽑는법 (0) | 2024.04.11 |
NLP text data 전처리에서 tokenizing할 때 padding이 필요한 이유 (0) | 2024.03.31 |