multi-head attentiond 개념 알아보고 간단하게 구현해보기

지금까지 이야기한 것은 word embedding vector들의 self attention을 단 1번만 수행했다는 점인데 이것을 확장하여 여러번 수행하고 싶다는 것이다.

 

왜 여러번 수행해야할까?

 

단 1번의 self attention은 1가지 측면에서만 word들의 attention 측면을 고려하지만 필요에따라 attention 측면을 여러 방면에서 수행할 필요가 있다.

 

특히 매우 긴 문장의 경우 ‘I went to the school. I studied hard. I came back home. I took the rest.’를 생각해보자.

 

이 문장을 해석하기 위해 단어 I에 대해서 고려해야할 대상은 went, studied, came, took 등 동사 측면도 있지만 그것의 대상이되는 school, home, rest 등 장소, 목적어 측면에서도 고려할 필요가 있다.

 

하나의 주어진 sequence에서 특정 word에 대해 여러 측면에서의 attention 정보를 모두 고려하고 싶다는 것이다.

 

이렇게 하면 앞에서 이야기했던 ‘자기 자신에게 너무 집중하는 문제’를 줄여서 다른 단어에 더욱 집중할 수 있게 한다.

 

 

1. 예시로 이해하는 multi-head attention

 

 

word sequence matrix X를 multi head인 head1,head2,....,head7 각각에 전부 넣는다.

 

각 head1,head2,head3,...,head7은 서로 다른 $W_{i}^{Q}, W_{i}^{K}, W_{i}^{V}$ ,i=1,2,3,...,7을 가지고 X의 query,key,value를 만든다.

 

그리고 각 head에서 얻은 attention 결과가 7개의 matrix로 나타날 것이다.

 

 

이 7개의 attention 결과를 모두 사용하기위해 그동안 사용했던 기술인 concat을 하는 것인데 concat했더니 차원이 너무 크다.

 

기본적으로 input vector와 output vector의 차원은 맞추고 싶거든..

 

이러한 차원을 줄이는 linear transformation $W^{O}$를 생각하였고 이것을 이용하여 input X와 차원을 동일하게 만든 하나의 multi head attention 결과를 얻는다.

 

 

2. multi head attention에서 주의해야할 부분

 

그동안 multi head attention을 어떻게 생각했냐면

 

하나의 word sequence matrix X를 여러개의 head1,head2,head3,....,head7 각각에 집어 넣는다.

 

그러면 (Q1,K1,V1), (Q2,K2,V2), (Q3,K3,V3),...(Q7,K7,V7)이라는 서로 다른 7개의 Q,K,V set을 만들고

 

각 head에서 각 Q,K,V set을 이용한 attention 연산 attention1,....attention7을 얻는다

 

이들의 concat을 한 뒤에 차원을 줄이는 선형변환을 수행하여 output을 얻는다.

 

근데 이렇게 생각하면 논문의 이 그림을 오해할 수가 있다.

 

 

어라?? word input matrix X를 넣어서 Q,K,V를 만드는 것은 알겠는데 이것을 linear layer에 넣는다고??? 그리고 나서 attention을 한다고??

 

그러면서 이 식이 이해가 안되는거임

 

 

뭐라고?? $QW^{Q}, KW^{K}, VW^{V}$를 attention 한다고??? 아니 word에 선형변환 행렬을 곱해야 Q,K,V가 나오는데???

 

논문에서 제시하는 multi head attention의 아이디어는 word sequence matrix X를 먼저 Q,K,V로 만든다.

 

그리고 Q,K,V의 선형변환을 attention 연산으로 사용한다는 것이다.

 

Q,K,V의 선형변환이 무슨 의미를 갖느냐? word가 attention하는 측면(path)을 다양하게 만들겠다는 것이다.

 

근데 사실 Q,K,V의 선형변환이  (Q1,K1,V1), (Q2,K2,V2), (Q3,K3,V3),...(Q7,K7,V7) 이렇게 7개의 SET을 만들면서 새로운 Q,K,V가 되는것이다.

 

그러니까 $QW_{1}^{Q} = Q_{1}$이고 $QW_{2}^{Q} = Q_{2}$,........ 이런 뜻

 

근데 이제 선형변환이라는 사실을 인지 못하면 이해를 못하는거지.

 

we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to dk, dk and dv dimensions, respectively.

 

On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding dv-dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2.

 

논문에서는 Q,K,V를 projection하고 attention을 수행해서 이들의 concatenation을 구하고 여기에 다시 projection을 한다고..

 

 

근데 중요하지는 않다. 결국 multi head attention을 관통하는 원리는 Q,K,V를 여러개 만들었다는 사실이 중요하다.

 

 

3. multi head attention의 실제 구현

 

정통 이론은 word embedding matrix를 Q,K,V로 만든 뒤 head수 만큼 $W^{Q},W^{K},W^{V}$를 준비하여

 

attention을 수행하고 결과를 합친 뒤 $W^{O}$로 차원을 줄인다.

 

실제 구현에서는 이렇게 하면 연산량이 너무 많아 메모리가 부족해진다.

 

그래서 하나의 Q,K,V를 만든 뒤 각각을 head수인 H개로 쪼개서 각각에서 attention을 수행하고 결과를 합치면서 차원을 줄이는 연산을 수행한다.

 

각 벡터의 차원은 서로 다른 정보를 가지기 때문에 여러 측면에서 attention하고자하는 원래의 아이디어에 위배되지 않으면서도 연산량을 줄이면서 구현할 수 있게 된다.

 

직관적인 multi head attention의 구현 그림

 

 

 

#multi head attention 내에 쓰이는 linear transformation matrix 정의
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

w_o = nn.Linear(d_model,d_model)

#(B,L,d_model)
q = w_q(batch_emb)
k = w_k(batch_emb)
v = w_v(batch_emb)

print(q.shape)
print(k.shape)
print(v.shape)
torch.Size([10,20,512])
torch.Size([10,20,512])
torch.Size([10,20,512])

#Q,K,V를 num_head개의 차원 분할된 여러개의 vector로 만든다
batch_size = q.shape[0]
d_k = d_model//num_heads

#(B,L,num_heads,d_k)
q = q.view(batch_size, -1, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)

 

 

이론에서는 마치 attention을 n개 써서 q,k,v를 n개 뽑아내는듯이 설명했는데

 

코딩으로는 한번 attention해서 q,k,v를 n개 split해서 썼다? 이게 맞는건가?

 

원래 논문에서 설명한 multi headed attention의 아이디어는 서로 다른 head attention의 linear transformation을 이용해

 

하나의 word vector에 대한 서로 다른 Q,K,V를 구성하고 각각에서 self attention 연산을 수행한 뒤 concat하여 차원을 줄인다는 것이다.

 

그러나 실제 구현할 때는 hidden state vector의 차원을 head 개수로 쪼개서 나온 작은 vector를 각각 하나의 head에 의해 나왔다고 가정하고 각각에서 attention을 수행한 뒤 합친다.

 

그러니까 너가 이해한 내용이 기본적으로 맞다.

 

그런데 정통식으로 multi-headed attention을 구현하면 중간에 벡터 차원이 head개수만큼 늘어나

 

원래 transformer 특성상 메모리를 많이 잡아먹는데 메모리를 더 잡아먹는 단점으로 작용한다.

 

그리고 이렇게 split해서 구현해도 되는 이유는

 

어차피 hidden vector의 각 차원이라는 것이 각각의 차원에서 자신이 보고있는 word vector의 특성 요소를 하나씩 담당하므로

 

이를 적절히 쪼개 subvector 관점에서 attention을 수행하여 합쳐도 결국 각각 다른 특성에 focus하는

 

multi-headed attention의 원래 아이디어에 위배되지 않는다.

TAGS.

Comments