knowledge distillation 자세하게

1. background

 

“model training과 deployment 단계에서 필요한 parameter는 다르다”

 

애벌레가 번데기가 되려면 다양한 환경에서 에너지와 영양소를 잘 흡수할 수 있어야함

 

그러나 번데기에서 나비로 어른이 될 때는 이와는 매우 다른 traveling, reproduction에 대한 요구사항이 필요함

 

이 때는 영양소를 흡수하는데 주력하기보다는 몸도 가볍고 생식도 잘하도록 최적화되어야함

 

머신러닝도 이와 마찬가지임

 

training 단계와 deployment 단계에서 필요로하는 요구사항이 완전히 다르다는 것임

 

training단계에서는 애벌레가 번데기가 되기위해 에너지를 잘 흡수하던것 처럼 주어진 대용량의 데이터로부터 구조와 지식을 잘 흡수해야함

 

deployment 단계에서는 대용량의 데이터를 계속 가지고 다니면서 지식을 흡수하는 것은 상당히 낭비이고

 

현실에서 대용량의 계산 inference를 빠르고 정확하게 할수있어야함

 

실제 knowledge distillation 논문의 서문으로 등장 배경이 바로 이거

 

 

2. distillation?

 

어떤 용질이 녹아 있는 용액(예를 들어 소금물)을 가열하여 얻고자 하는 액체(순수한 물)의 끓는점에 도달하면 기체상태의 물질(수증기)이 생긴다.

 

이를 다시 냉각시켜 액체상태로 만들고 이를 모으면 순수한 액체(순수한 물)를 얻어낼 수 있는데 이것을 증류(distillation)

 

비슷한 논리로 knowledge distillation은 불순물이 섞인 여러 지식을 어떠한 방법을 통해서 엑기스만 뽑아 작게 만들어 작은 network에 전달하자는 방법

 

 

3. train stage vs. deployment stage

 

deep learning model train에는 좋은 성능(high accuracy)을 위해 모델도 커야하지만(overparameterized)

 

상당히 많은 양의 데이터가 들어가는데(using the computation power in a data center)

 

이렇게 커다란 모델을 학습시키고 mobile device에 deploy할 때 모델만 전달하면 되는데

 

대용량의 데이터도 mobile에 전달해서 들고 다니기에는 너무 낭비다

 

모델은 단순히 데이터를 train하여 representation하는 장치인데 모델만 전달하면 됐는데 굳이 대용량의 데이터까지 전달해야하는가?

 

데이터를 전달하는 이유는 device에 deploy된 model을 retrain하거나 update할 때 쓰이나봐

 

 

train후에 smartphone에 deploy할 때 설명하는 그림

 

 

datacenter에서 train한 커다란 모델의 지식을 뽑아내 작은 network에 전달

 

학습 당시에 parameter가 전부 필요하지 않다

 

실전에서 빠르게 계산을 하기 위해 최소한의 parameter만 있으면 충분하다

 

번데기 당시의 두꺼운 껍데기가 필요한 것이 아니라 실전에서 나비는 가벼운 날개를 달아 가볍게 나는게 중요함

 

 

4. knowledge transfer

 

선생님이 학생에게 가르칠 때 선생님의 지식이 학생에게 전달되어 학생이 배우는 모습에 아이디어를 얻어

 

상대적으로 large model인 teacher model이 배운 지식을 상대적으로 small model인 student model에 전달하겠다는 것이 knowledge distillation의 핵심 아이디어

 

 

 

5. transfer learning이랑 무슨 차이???

 

transfer learning은 서로 비슷한 domain에서 상대적으로 데이터 양이 많은 domain의 model로부터

 

상대적으로 데이터 양이 적은 domain의 model에 지식을 전달하여 작동하게 만드는 learning

 

예를 들어 영어와 프랑스어는 domain은 다른데 둘이 비슷한 점이 많음

 

영어의 경우 상대적으로 데이터 양이 많은데 프랑스어는 상대적으로 데이터 양이 적음

 

영어로부터 지식을 배운 model을 이용하여 상대적으로 적은 양의 데이터인 프랑스어 model이 조금 더 잘 작동하게 만들 수 있을까?

 

transfer learning을 이용하여 영어로부터 지식을 배운 model에 프랑스어 데이터를 넣어서 training을 시키면

 

(혹은 뒷단만 제거해서 fine tuning) 그냥 프랑스어 데이터를 배운 model보다는 성능이 좋은 경우가 많지

 

 

반대로 knowledge distillation은 teacher model과 student model이 학습해야하는 domain이 완전히 같다

 

model의 size를 작게 만드는 것에 중점을 둔다

 

상대적으로 큰 teacher model이 배운 지식의 엑기스를 뽑아 상대적으로 작은 student model로 전달하고자 하는 것

 

지식을 그대로 전달하면서 domain이 같은 것이고 model size를 줄이겠다는 것이 중점

 

6. 왜 knowledge distillation을 사용해야하는가?

 

그냥 student model같은 작은 모델 혼자 training하면 되는거 아닌가?

 

왜 복잡하게 teacher model을 이용해서 작은 student model을 training하는가?

 

student alone training vs. knowledge distillation training

 

 

완전히 같은 구조의 student network를 동일한 데이터셋으로 혼자서 training했을 때 정확도랑

 

teacher model을 이용한 knowledge distillation training했을 때 정확도랑 비교해보면

 

knowledge distillation했을 때 정확도가 조금 더 높은 것이 일반적

 

 

7. softmax function

 

knowledge distillation의 핵심 개념으로 변형된 softmax function을 사용함

 

 

 

위 그림이 knowledge distillation에서 사용된 변형된 softmax function의 정의를 나타내는데

 

T=1이면 이미 잘 알려진 softmax function

 

T>1이면 softmax distribution이 더욱 soft하게 만들어짐?

 

T가 커지면서 변형된 softmax 함수의 변화

 

 

softmax함수가  위 그림처럼 특정 class에 대해서만 상당히 높은 확률을 주고 나머지는 0에 가깝게 내놓는 hard distribution 경향이 있는데

 

T를 증가시키면 상당히 높은 확률값은 낮추고 0에 가까운 나머지 확률들은 조금 높여서 soft distribution 경향이 있다

 

전체 면적 합이 1이어야하니까 T를 증가시키면 봉우리가 낮아지는 만큼 평평한 부분이 조금 올라감

 

논문에서 실험적으로 최적의 T를 2.5~4정도라고 밝히고 있는데 문제마다 차이가 있으므로 직접해보지 않으면 모름

 

But when this was radically reduced to 30 units per layer, temperatures in the range 2.5 to 4 worked significantly better than higher or lower temperatures. (조건이 붙었을때 2.5~4이면 효과적이라고 되어있거든..)

 

8. 예시로 이해하는 soft distribution

 

다 합쳐서 1이 되어야하는데 일부분만 가져와서 그렇다

 

 

 

일반적으로 hard label이란 정확히 해당 target class에는 1을 주고 나머지에는 0을 주는 label 방법

 

만약 teacher network가 학습을 잘 하여 input data로부터 적절한 logit을 구해 softmax를 구하면 내놓는 output은 위 그림에서 두번째

 

해당 target class에는 0.9라는 높은 값을 주고 나머지는 거의 0에 가까운 값을 줌

 

일반적으로 이런 softmax 예측을 하면 dog 0.9로 예측하면 dog에 1을 주고 나머지는 그냥 버린다는거지

 

이렇게 내놓는 일반적인 softmax 함수의 prediction을 hard prediction이라고 부른다.

 

 

그러나 논문의 저자는 예측 결과로부터 ‘이것은 개이다’라고 하는 것만 지식이 아니다

 

결과에서도 사족동물들끼리 개와 고양이는 비슷한면이 있고 개와 차는 전혀 다른 면이 있고

 

label들끼리 사이에서도 상관관계가 있어서 개와 고양이가 가까운 정도가 개와 차가 가까운 정도보다 더 크다

 

그러나 이런 정보를 전혀 고려하지 않고 일반적으로 정답은 개다 라고만 하고 나머지는 버림

 

 

그렇지만 논문의 저자는 이렇게 prediction한 label 분포를 보면 그 안에도 숨어있는 지식을 발견할 수 있다는 것임

 

이러한 지식을 dark knowledge라고 부름

 

상식적으로 개와 고양이 사이의 거리가 고양이와 차 사이의 거리보다 가까울 필요가 있는데 0.9,0.1이랑 0.1,10^(-9)사이 거리는 상당히 아쉽다 이거임

 

그래서 softmax함수의 logit z에 T로 나눠서 이러한 점을 반영한 것이 변형된 softmax함수이고 이것 마저도 지식으로 배워야한다는 것임

 

이 함수가 내놓는 prediction을 soft prediction이라고 부른다

 

 

 

부정확한 output(고양이,차,소 등등)의 상대적인 확률을 보면 model이 어떻게 generalize하는 경향이 있는지 알려준다

 

 

9. knowledge distillation

 

input data를 pre-trained teacher model에 넣으면 해당 데이터에 대한 softmax 예측값인 soft label을 내놓는다

 

이 soft label은 일반적인 softmax 함수 예측이 아닌 T=t에서의 soft prediction에 의한 softmax 함수 예측

 

label이라고 이름이 붙은 이유는 teacher model이 내놓는 예측값이 정답이라고 보는 느낌?

 

그러니까 student model이 teacher model의 지식을 모사해야하는데 모사하기위해서는 label이 필요하니까 그렇다

 

student model은 두가지 예측을 내놓는데 teacher의 지식을 모사하기 위한 soft prediction과

 

T=1에서 softmax를 통해 실제 정답 hard label을 모사하기 위한 hard prediction

 

student의 soft prediction과 teacher가 내놓는 soft prediction을 soft label로 보고 두개의 차이를 distillation loss로 만들고

 

student의 hard prediction과 실제 정답 hard label과의 차이로 student loss를 만든다

 

 

 

참고로 student와 teacher model의 architecture는 완전히 달라도 상관없음

 

보통은 teacher가 상대적으로 parameter가 많은 모델을 쓰고 student가 상대적으로 적은 model을 쓸테지만

 

당연히 두 model이 목표로하는 task가 같아야함

TAGS.

Comments