transformer decoder에 사용된 masked self attention에 대해 알아보고 구현하기

decoder는 특이하게 masked multi head attention을 먼저 수행한다.

 

이것은 decoder 내부에서 이루어지는 self attention 과정으로 decoder의 input sequence끼리만 이루어진다.

 

언어모형을 학습시킬때 이미 정답을 아는 상태에서 학습을 시킨다.

 

i go home을 번역하라고 할때 decoder에는 '<sos> 나는 집에 간다'를 넣고 '나는 집에 간다 <eos>'를 순차적으로 뱉게 학습을 시킨다는 것이다.

 

언어 모형의 학습과정

 

decoder에 input으로 '<SOS>'를 넣어주면 output이 '나는'이 나오길 바라고

 

'나는'을 input으로 넣어주면 '집에'가 나오길 바란다. 그런식으로 학습을 시킨다.

 

하지만 이런 학습이 inference에는 어울리지 않는다는 점이다.

 

test과정에서는 정답을 모른채로 <sos>를 넣으면 어떤 예측값을 만들고 이 예측값이 정답이든 아니든 모른채로 다음 예측을 위한 input으로 사용한다.

 

decoder의 masked multi attention은 이런 속임수를 방지하고싶은 것이다.

 

decoder가 오직 현 시점의 이전 단어들의 정보에만 집중하도록 한다는 것이다.

 

다음 단어들의 정보에도 집중해버리면 그것은 예측 관점에서 말이 안되니까

 

 

input sequence과정에서 attention을 위한 query, key,value를 만드는데 query, key의 내적 결과가 다음과 같다면

 

 

query,key의 내적 결과를 matrix로 만들었음

 

다음 단어에는 집중하지 못하도록, 오직 현 step의 이전 단어에만 집중할 수 있도록 대각선 위 부분의 score를 0으로 만들어버림..

 

 

다음 단어에 집중하게되는 부분은 마스크로 가리듯이 0으로 만들어버림

 

마지막으로 확률의 합은 1이어야하니까 1이 안되는 부분은 normalization을 시킨다.

 

방법은 각 행의 성분을 현재 행의 성분의 합으로 나눈다.

 

0.91,0,0은 0.91로 각각 나눠서 1,0,0

 

(0.42,0.47,0)은 0.89로 나눠서 (0.42/0.89 , 0.47/0.89, 0)

 

 

이렇게 얻어진 score로 attention을 수행하여 hidden vector를 생성한다.

 

residual connection, layer normalization을 수행한 뒤에 decoder의 2번째 multi head attention의 query로 사용하게 된다.

 

 

1. masked self attention의 구현

 

sequence 데이터에 pad가 있는 경우에는 pad도 마스크를 해서 계산을 더욱 줄인다.

 

pad는 의미없는 데이터이기때문에 굳이 attention 시킬 필요가 없다.

 

 

 

실제 구현 코드, padding mask와 0으로 만들어야하는 mask를 만들어 둘을 합친다.

 

 

 

 

 

실제는 이제 마스크부분을 -inf로 주고 softmax를 취하는 과정을 거쳐 0으로 만든다

 

 

 

 

 

TAGS.

Comments