본문 바로가기
Python

[Python] Pytorch에서 벡터 합치기 (임베딩 결합 방식)

by teamnova 2025. 8. 19.
728x90

 

안녕하세요

딥러닝에서는 여러 임베딩을 결합해서 새로운 입력으로 활용하는 경우가 많습니다.
예를 들어 텍스트 임베딩 + 감정 임베딩, 사용자 임베딩 + 아이템 임베딩, 단어 임베딩 + 위치 임베딩 등이 있습니다.

이번 포스팅에서는 PyTorch에서 자주 사용되는 벡터(임베딩) 합치기 방식 5가지에 대해서
간단한 예제 코드와 실행 결과를 함께 살펴보겠습니다.

 

우선 터미널을 통해 파이토치를 설치해줍니다.
> pip install torch 

전체 코드입니다.

 

"""
PyTorch 벡터 합치기 예제
- 덧셈 (Sum)
- 평균 (Mean)
- Concatenation
- 가중합 (Weighted Sum)
- Concatenation + Linear (Projection)
"""

import torch
import torch.nn as nn

# ==========================================
# 1. 입력 임베딩 정의
# ==========================================
# 예시로 2개의 임베딩 벡터(batch=2, dim=4)를 준비
a = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [0.5, 1.5, 2.5, 3.5]])
b = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [1.0, 1.0, 1.0, 1.0]])

print("===== 입력 벡터 =====")
print("a:\n", a)
print("b:\n", b)
print()

# ==========================================
# 2. 덧셈 (Element-wise Sum)
# ==========================================
sum_emb = a + b
print("===== 덧셈 결과 (a + b) =====")
print(sum_emb, "\n")

# ==========================================
# 3. 평균 (Element-wise Mean)
# ==========================================
mean_emb = (a + b) / 2
print("===== 평균 결과 ((a+b)/2) =====")
print(mean_emb, "\n")

# ==========================================
# 4. Concatenation (벡터 이어붙이기)
# ==========================================
# dim=1 : 특징 차원 기준으로 붙임
concat_emb = torch.cat([a, b], dim=1)
print("===== Concatenation 결과 (cat[a,b]) =====")
print(concat_emb, "\n")

# ==========================================
# 5. 가중합 (Weighted Sum)
# ==========================================
alpha = 0.7  # a의 비중
weighted_emb = alpha * a + (1 - alpha) * b
print("===== 가중합 결과 (αa + (1-α)b) =====")
print(weighted_emb, "\n")

# ==========================================
# 6. Concatenation + Linear (Projection)
# ==========================================
# concat 후 차원이 8이 되었으니, Linear로 다시 4차원으로 축소
linear = nn.Linear(8, 4)
projected_emb = linear(concat_emb)
print("===== Concatenation + Linear 결과 =====")
print(projected_emb, "\n")

 

 

실행 결과입니다.