728x90
안녕하세요
딥러닝에서는 여러 임베딩을 결합해서 새로운 입력으로 활용하는 경우가 많습니다.
예를 들어 텍스트 임베딩 + 감정 임베딩, 사용자 임베딩 + 아이템 임베딩, 단어 임베딩 + 위치 임베딩 등이 있습니다.
이번 포스팅에서는 PyTorch에서 자주 사용되는 벡터(임베딩) 합치기 방식 5가지에 대해서
간단한 예제 코드와 실행 결과를 함께 살펴보겠습니다.
우선 터미널을 통해 파이토치를 설치해줍니다.
> pip install torch
전체 코드입니다.
"""
PyTorch 벡터 합치기 예제
- 덧셈 (Sum)
- 평균 (Mean)
- Concatenation
- 가중합 (Weighted Sum)
- Concatenation + Linear (Projection)
"""
import torch
import torch.nn as nn
# ==========================================
# 1. 입력 임베딩 정의
# ==========================================
# 예시로 2개의 임베딩 벡터(batch=2, dim=4)를 준비
a = torch.tensor([[1.0, 2.0, 3.0, 4.0],
[0.5, 1.5, 2.5, 3.5]])
b = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[1.0, 1.0, 1.0, 1.0]])
print("===== 입력 벡터 =====")
print("a:\n", a)
print("b:\n", b)
print()
# ==========================================
# 2. 덧셈 (Element-wise Sum)
# ==========================================
sum_emb = a + b
print("===== 덧셈 결과 (a + b) =====")
print(sum_emb, "\n")
# ==========================================
# 3. 평균 (Element-wise Mean)
# ==========================================
mean_emb = (a + b) / 2
print("===== 평균 결과 ((a+b)/2) =====")
print(mean_emb, "\n")
# ==========================================
# 4. Concatenation (벡터 이어붙이기)
# ==========================================
# dim=1 : 특징 차원 기준으로 붙임
concat_emb = torch.cat([a, b], dim=1)
print("===== Concatenation 결과 (cat[a,b]) =====")
print(concat_emb, "\n")
# ==========================================
# 5. 가중합 (Weighted Sum)
# ==========================================
alpha = 0.7 # a의 비중
weighted_emb = alpha * a + (1 - alpha) * b
print("===== 가중합 결과 (αa + (1-α)b) =====")
print(weighted_emb, "\n")
# ==========================================
# 6. Concatenation + Linear (Projection)
# ==========================================
# concat 후 차원이 8이 되었으니, Linear로 다시 4차원으로 축소
linear = nn.Linear(8, 4)
projected_emb = linear(concat_emb)
print("===== Concatenation + Linear 결과 =====")
print(projected_emb, "\n")
실행 결과입니다.
'Python' 카테고리의 다른 글
[Python] gRPC로 양방향 통신하기 (0) | 2025.08.22 |
---|---|
[Python] 음성 데이터 품질 검사(QC) 자동화 리포트 만들기 (0) | 2025.08.20 |
[Python] JSONL 포맷으로 음성 데이터셋 정리하기 (1) | 2025.08.18 |
[Python] librosa로 WAV 파일 무음 제거하기 (1) | 2025.08.17 |
[Python] PyTorch 활용해서 손글씨 데이터를 숫자로 분류하기 (2) | 2025.08.13 |