본문 바로가기
Python

[Python] 간단한 Attention 메커니즘 구현하기 (Query, Key, Value 이해하기)

by teamnova 2025. 8. 28.
728x90

 

안녕하세요 오늘은 딥러닝 모델에서 핵심으로 쓰이는 Attention 매커니즘을 파이썬 코드로 직접 구현해보겠습니다. 

Attention 이란 "어떤 입력이 현재 상황에서 얼마나 중요한지" 를 계산해 가중치를 주는 방식을 의미합니다. 

 

Attention 이란 

사람도 문장을 읽을 때 모든 단어를 동일하게 보지 않고 상황에 따라 더 중요하다고 판단되는 단어에 집중(attention) 합니다. 

딥러닝에서의 Attention 개념 또한, 현재 생성해야할 출력 (ex. 다음 글자, 다음 음향 프레임 등) 에 맞춰 과거 정보 중에서 관련성이 높은 부분에 더 큰 가중치를 주고, 그 가중치를 활용해 다음 출력을 만듭니다. 

즉, Attention은 모델이 ‘모든 입력을 동일하게 보지 않고, 중요한 부분에 집중하게 만드는 장치’입니다.

 

모델은 매 스텝마다 "지금 내가 만들고 싶은 것(Query)"과 "데이터베이스 내의 각 조각들(key)" 을 비교해 "유사도 점수" 를 계산합니다. 점수가 높을수록 지금 필요한 정보에 가깝다는 뜻이 됩니다. 

이 점수는 최종 출력에 반영되어, 모델이 더 정교한 결과를 낼 수 있게 합니다

 

유사성에 따른 가중치 점수가 필요한 이유

1. 노이즈 억제 : 관련 없는 정보는 낮은 가중치를 적용해 영향을 최소화합니다. 사용자가 기대한 대답과 거리가 먼 대답이 출력될 가능성을 낮추는 것입니다. 

2. 모호성 해소 : 비슷한 후보들 중 문맥상 맞는 것을 크게 키워 동음이의어 or 모호한 상황을 풀어집니다. 

3. 장기 의존성 처리 : 물리적으로 떨어져있는 정보라도, 유사도만 높다면 즉시 참조할 수 있어 긴 문맥을 다룰때 유리합니다! 

4. 다각적 관계 포착(멀티헤드): attention 개념에는 헤드 개념이 필수적입니다. 여러개의 관점(음향데이터일 경우 의미, 리듬 등) 을 동시에 고려해 다각도에서 연관성을 판단합니다. 

 

기본 개념 정리 

Attention이란?

  • 입력 시퀀스(텍스트, 음성 프레임 등) 중 중요한 부분에 더 집중하도록 가중치를 주는 메커니즘.
  • 번역, 음성 합성, 이미지 캡션 등 다양한 AI 모델의 핵심 기술.

Query, Key, Value

  • Query(Q): 내가 지금 찾고 싶은 정보
  • Key(K): 참고할 수 있는 후보들의 이름표
  • Value(V): 후보가 가진 실제 정보
    → Query와 Key의 유사도를 비교해 가중치(Attention Score)를 구하고, 그 비율만큼 Value들을 섞어 최종 결과를 만듭니다.

Self-Attention vs Cross-Attention

  • Self-Attention: Q, K, V가 같은 시퀀스에서 나옴 (문장 내 단어들끼리 관계 계산)
  • Cross-Attention: Q는 한쪽(예: 텍스트), K/V는 다른 쪽(예: 음성)에서 나옴

 

1. Attention 예제 코드 

 

 

1) 함수 정의

def scaled_dot_product_attention(query, key, value):

 

  • Attention의 핵심을 함수화.
  • query, key, value는 보통 배치(batch) 단위로 들어오기 때문에 (batch_size, seq_len, dim) 형태의 텐서.

 

 

2) 유사도 점수 계산

scores = torch.matmul(query, key.transpose(-2, -1))
  • Query (Q)와 Key (K)의 내적(dot product)을 통해 유사도 점수를 구함.
  • key.transpose(-2, -1) : 마지막 두 차원을 바꿔서 행렬 곱이 가능하게 변환.
  • 즉, Query 하나당 Key 전체와의 관련도를 모두 뽑아낸다.

 

3) 스케일링 (scaling) 

d_k = key.size(-1)
scores = scores / (d_k ** 0.5)

 

 

  • 차원 수가 커질수록 dot product 값이 기하급수적으로 커짐 → softmax가 한쪽으로 몰려버리는 문제 발생.
  • 라서 보통 차원의 제곱근으로 나눠서 값의 크기를 안정화시킴. 
  • 이름 그대로 Scaled Dot-Product Attention 의 “Scaled” 부분

 

*soft max 란 ? : 주어진 점수(실수 벡터) 를 0~1 사이값으로 변환하고 그값들의 합이 1이 되도록 정규화 하는 함수. 

 

4) Softmax 적용 

attention_weights = F.softmax(scores, dim=-1)
  • softmax 함수 -> 점수를 확률 분포(0~1 사이, 합=1) 로 변환.
  • dim=-1은 Key 방향으로 softmax 적용 → Query가 각 Key에 얼마나 집중할지 확률화한다. 

 

 

5) Value 와 결합 및 반환 (return) 

output = torch.matmul(attention_weights, value)

return output, attention_weights

 

  • Attention Weights × Value → 가중합.
  • 즉, “Query가 Key와 비교해 중요도가 높은 Value를 더 많이 섞어낸 결과.”

 

output: 최종 컨텍스트 벡터 (모델이 다음 단계에 참고할 정보)