ResNet의 핵심 아이디어인 skip connection과 Residual learning
1. deep neural network는 학습하기가 어렵다
overfitting이란 train error가 줄어드는데도 test error는 증가하는, 방향이 반대되는 현상으로 parameter 수가 늘어나면 일반적으로 발생한다.
일반적으로 deep한 neural network는 shallow한 network에 비해 학습하기가 어렵다.
train error가 줄어들면서 test error도 어느정도 줄어드니까 위와 같은 경우는 overfitting은 아니다.
물론 test error가 너무 커지는게 문제다.
아무리 학습을 잘 시킨다고해도 결국엔 20-layer가 56-layer보다 나았음
왜 학습하기가 어려웠나? 깊을수록 gradient vanishing 문제가 발생했기 때문이다
ResNet은 skip connection이라는 아이디어를 통해 gradient가 더 잘 전달되도록 하였다
직관적으로 생각해보자.
일반적인 layer는 input x가 direct로 전달되어 f(x)로 완전히 변형되어 나오는데
skip connection layer는 x에서 변형된 f(x)가 나오면서 여기에 x의 정보를 그대로 가지는 x를 더해준다
그만큼 깊을수록 x가 완전히 변형되어 f(x)는 x의 정보를 갖지 못하는데
skip connection으로 변형된 f(x)에 원래의 x의 정보를 더해주겠다는 의미
1층이라고 생각하면 이해하기 어렵지만 x가 이전 layer에서 나온 output이라고 생각하면
이미 학습된 값인 x의 정보를 보존하지 않는 f(x)와 이미 학습된 값 x를 보존하는 f(x)+x 뭐가 더 좋겠는가? 당연히 후자다.
2. Residual learning
ResNet의 핵심적인 아이디어
일반적인 neural net은 input x에 대한 target y를 mapping 하는 함수 H(x)를 찾는 것이어서
cost인 H(x)-y를 최소화하는 방향으로 찾는다.
그러니까 원하는 함수를 H(x)라하면 x와의 잔차함수 H(x)-x=F(x)라 하자. 그러면 원하는 함수는 F(x)+x가 된다.
이것을 구현한 microsoft는 H(x)-x를 학습시키는것이 H(x)를 학습시키는 것보다 더 쉽다고 가설을 세운 것임.
특별한 이유없이 잘 될것 같다는 생각에 실험을 해본거고 실제로 잘 된거임
자연스러운 생각인 것이 H(x)에 x의 정보를 뺀 H(x)-x가 당연히 학습해야할 정보가 적음
최종적으로 학습된 결과에 x를 더해주면 그것이 우리가 원하는 함수 H(x)와 동일하게 됨. 이것이 residual learning이다
일반적으로 local optimum에 덜 빠지게 된다는 것이 수학적으로 증명되었다.
그리고 gradient vanishing이 잘 일어나지 않는다고 한다.
3. ResNet의 특별한 구조
3-1) shortcut
그냥 바로 더하는 것이 simple shortcut
보통 이 simple shortcut이 성능이 좋음
학습시킨 결과와 차원이 안맞으면 1*1 conv로 차원을 줄여서 output과 차원을 맞추고 더해주는 것이 projected shortcut 그런데 잘 안쓴다고 한다
3-2) batch normalization
batch normalization은 ReLU전이냐 후냐 하지말아야하냐 논란이 있다는데?
3-3) bottleneck architecture
오른쪽이 bottleneck architecture
1*1 conv를 3*3 conv 앞 뒤로 배치하여 일단 channel을 줄이고 conv를 통과시킨뒤
다시 channel을 늘리는 1*1 conv를 사용하여 input과 output 크기를 동일하게 하면서
layer는 늘리고 parameter는 줄임
참고
https://89douner.tistory.com/64
'딥러닝 > Computer Vision' 카테고리의 다른 글
가장 좋은 data augmentation이 있을까?(random augmentation) (0) | 2022.02.13 |
---|---|
컴퓨터비전에서 사용하는 기본적인 data augmentation 1 (0) | 2022.02.12 |
1*1 convolution은 왜 중요한가? (0) | 2022.02.08 |
GoogleNet의 핵심 아이디어 inception module, auxiliary classifier, 1*1 convolution 알아보기 (0) | 2022.02.08 |
VGGNet는 왜 3*3 convolution을 사용했을까? (0) | 2022.02.07 |