knowledge distillation에 대하여 기초

1. basic idea

 

이미 학습을 완료한 teacher model의 지식을 더 작은 student model에게 주입하는 것이 목적

 

큰 모델에서 작은 모델로 지식을 전달하여 모델 압축에 유용하게 쓴다

 

최근에는 teacher model의 출력을 unlabeled data의 pseudo-label로 자동생성하는 방식으로 활용함

 

 

2. unsupervised learning

 

label이 존재하지 않는 동일한 input을 이미 학습한 teacher model과 학습이 안된 student model에 넣어 각각 output을 출력

 

같은 입력에 대해 서로 다른 출력을 내놓을 것인데 둘의 차이를 구해주는 KL divergence loss를 구한다

 

이 loss를 최소화하는 방향으로 backpropagation에 의해 student model만 학습한다.

 

그러면서 student model이 teacher model의 행동을 따라하게 된다

 

 

참고로 input X는 teacher model이 학습할때 썼던 데이터 일수도 있지만 전혀 다른 데이터여도 사실 상관없다

 

 

3. supervised learning

 

label이 있는 데이터의 경우 True label이 존재하니까 student model의 출력과 비교하여 cross entropy를 이용한 student loss를 구할 수 있다

 

student의 예측이 정답을 맞추도록 학습을 진행하는 것이다.

 

바로 위에서 했던 teacher model과 student model의 출력 차이를 나타내는 KL divergence loss인 distillation loss는 여전히 구할 수 있다

 

soft label과 soft prediction은 하나의 distribution이므로 KL divergence loss를 구할 수 있다

 

근데 사실 KL divergence loss나 cross entropy loss나 거기서 거기라 cross entropy loss를 구한다고도 한다(논문에서는 특별히 없고 KL은 있는것 같다)

 

아무튼 이 loss(KL divergence)는 student가 teacher를 얼마나 잘 따라하는지 측정하는 것이라고 할 수 있다.

 

 

 

 

knowledge distillation은 두가지 loss를 한꺼번에 줄이는 것을 목표로 하고 있다.

 

teacher와 student의 softmax loss를 닮게 만든다 +  student의 순수 분류 성능 loss도 같이 줄이는 것을 목표

 

 

 

첫번째는 student가 혼자서 분류하는 성능을 나타내는 loss로 student의 예측값과 ground-truth의 차이에 대한 cross-entropy loss를 최소화시킨다

 

두번째는 위에서 설명한 teacher와 student가 각각 분류한 softmax logit의 차이를 cross-entropy loss로 하여 최소화시키는 방법으로 학습을 진행한다.

 

 

 

4. soft prediction?

 

주목해야할 첫번째 포인트는 soft prediction이다

 

label은 hard label과 soft label이 있다.

 

hard label은 class를 one hot vector로 나타내는 것이고

 

soft label은 각 class에 속할 확률을 나타내는 것으로 softmax vector이다

 

사실 softmax가 hard하다는 것은 vector 원소 값들이 편차가 매우 커서 one hot vector에 가까워지는 것이고

 

soft하다는 것은 vector 원소 값들이 편차가 작다는 것이다.

 

 

soft label의 값들은 모델이 가지고 있는 지식이라고 볼 수 있다.

 

모델이 하나의 input을 받아들였을 때 input이 각 class에 속하는 정도가 어느 정도일지 생각하는 정도가 각각 확률로 나타나는 것이 soft prediction

 

5. Temperature

 

두번째로 주목할 부분은 softmax(T=t)이다

 

일반적인 softmax(X)는 X가 크면 매우 커지고 X가 작으면 매우 작아지는 편차가 큰 함수이다.

 

왜냐하면 지수함수 자체가 값이 조금만 커져도 지수적으로 크게 증가하는 경향이 있어서 그렇다

 

하지만 충분히 큰 T에 대해 softmax(X/T)는 그런 효과를 줄여서

X의 크기에 무관하게 편차가 없이 더욱 soft한 값들을 가질 수 있게 해준다.

 

왜냐하면 T로 나누는 것이 X가 큰 것에 대해 큰 penalty를 주는 효과가 있다.

 

예시1

 

예시2

 

이렇게 큰 값으로 나눴을 때 softmax vector의 편차가 줄어든다는 사실은 transformer에서 이미 증명했다

 

temperature라는 표현이 되게 재미있는 표현인데 soft label의 값들은 모델이 각 class에 대해 가지고 있는 지식을 의미한다고 볼 수 있는데

 

temperature T가 클수록 distillation이 더 잘 일어나잖아?

 

즉, T가 클수록 softmax vector의 값들이 더욱 편차가 줄어들어 더욱 soft해져서 각 class의 값들이 분명하게 살아나

 

다른 모델로 각 class의 지식들이 더욱 잘 전달된다

 

그러니까 teacher model의 지식이 student model로 잘 전달되어 student가 teacher를 더욱 잘 따라하게 된다.

 

softmax 값들이 있어야 각 class에 대해 배울게 있다는 거임

 

T는 hyperparameter이고 2~4정도일때 효과적이라고 논문에 알려져있다

 

 

6. semantic information을 크게 고려하지 않는다

 

무슨 말이냐면 teacher에서 나온 label의 개별 원소 각각을 student가 따라하려 한다기보다는

 

전체 형태가 하나의 추상적인 지식을 표현하며 그것을 따라하는게 중요하다는 것이다.??? 근데 이게 중요한 말인지는 모르겠다

 

 

7. backpropagation

 

distillation loss와 student loss의 가중합을 전체 loss로 보고 이것을 최소화하는 방향으로 student model을 학습

 

gradient가 student model로만 흐른다는 소리이다.

 

backpropagation을 직관적으로 그림

 

 

https://light-tree.tistory.com/196

 

딥러닝 용어 정리, Knowledge distillation 설명과 이해

이 글은 제가 공부한 내용을 정리하는 글입니다. 따라서 잘못된 내용이 있을 수도 있습니다. 잘못된 내용을 발견하신다면 리플로 알려주시길 부탁드립니다. 감사합니다. Knowledge distillation 이란?

light-tree.tistory.com

 

TAGS.

Comments