knowledge distillation 간단하게

이미 학습된 큰 규모의 teacher network가 있다면 작은 student network 학습시 teacher network의 지식을 전달하여 학습을 시키자.

 

 

1. 일반적인 방법

 

주어진 input x를 pretrained teacher model과 student model에 넣어서 output을 낸다

 

teacher model의 경우 softmax(T=t)를 사용하여 soft label을 내놓고

 

student model은 softmax(T=1)의 hard label과 softmax(T=t)의 soft label을 모두 내놓는다

 

knowledge distillation 개념도

 

 

A부분에서는 student model의 hard prediction을 이용하여 ground truth와의 cross entropy를 이용한 일반적인 training이 이루어짐

 

B부분에서는 teacher model의 지식 주입이 일어남

 

teacher model과 student model의 soft prediction을 이용한 KL divergence loss를 계산하여 teacher의 지식을 전달

 

 

2. hard label과 soft label

 

image를 예로 들어보면 여러 사물이 들어가도 hard label로 표현하면 단 하나만 담을 수 있음

 

 

 

위에 제시된 그림을 사자가 앞을 보고 있다고 해서 label=사자로 할수는 있지만

 

호랑이도 사자와 비슷한 만큼이나 이미지에 들어가있는데 호랑이의 정보는 완전히 무시하게 된다

 

network가 볼 때는 사자와 호랑이가 둘다 있는데 사자라고만 해버리면 오히려 혼동하게 된다는 것

 

soft label로 사자와 호랑이가 0.5, 0.5만큼 들어가있다고 표현한다면 이미지의 정보를 hard label에 비해 많이 표현할 수 있게 된다.

 

teacher model이 이런 soft label을 사용하는 이유는 soft label의 특성을 살려 데이터의 최대한 많은 정보를 student model에 전달하고자하기 위함

 

이런 사자와 호랑이가 같이 있는 상황의 데이터에는 teacher model이 ‘사자와 호랑이가 같이 있다’라는 상황 전달도 가능하게 만들었다는 것임

 

soft label 예시

 

 

밝을수록 확률이 높은부분이라고 할 때 label 2로 classification을 하겠지만 label1도 어느정도 연관이 있다는 사실을 파악할 수 있다

 

 

3. knowledge distillation loss

 

 

 

 

 

 

빨간색 부분이 student의 일반적인 학습을 나타낸 loss

 

파란색 부분이 teacher model로부터 지식을 전달받는것을 나타내는 loss

 

두 loss의 합을 최종 loss로 사용하여 student model의 training을 수행함

 

T는 temperature로 softmax의 결과 분포를 soft하게 만들어주는 hyperparameter

 

 

TAGS.

Comments