Decoder
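The listing below implements a simplified, decoder-only Transformer stack in PyTorch: each layer applies masked multi-head self-attention followed by a position-wise feed-forward network, with a residual connection and layer normalization around each sub-layer. The attention computed inside `MultiHeadAttention` is standard scaled dot-product attention,

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V,
$$

where $d_k$ is the per-head dimension (`head_dim` in the code).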
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections
        Q = self.wq(query)
        K = self.wk(key)
        V = self.wv(value)
        # Split into multiple heads: (batch, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.head_dim ** 0.5
        if mask is not None:
            if mask.dim() == 3:
                # Add a head dimension so a (batch, seq_len, seq_len) mask
                # broadcasts against scores of shape (batch, num_heads, seq_len, seq_len)
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Final linear layer
        output = self.fc(attn_output)
        return output
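
# Shape summary for MultiHeadAttention: for an input of shape
# (batch, seq_len, d_model), Q, K and V are reshaped to
# (batch, num_heads, seq_len, head_dim), the attention scores/weights have
# shape (batch, num_heads, seq_len, seq_len), and the returned tensor has
# the same shape as the input, (batch, seq_len, d_model).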

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))
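
# Note: the feed-forward network is applied independently at every position;
# in the original Transformer, d_ff is typically 4x d_model (e.g. 2048 vs. 512).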

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        # Feed-forward network
        ff_output = self.feed_forward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)
        return x
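
# Note: this layer uses the post-norm ("Post-LN") arrangement of the original
# Transformer: LayerNorm is applied after the residual addition in both sub-layers.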

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x
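
# Note: this is a decoder-only (GPT-style) stack. A full encoder-decoder
# Transformer decoder layer would also include a cross-attention sub-layer
# over the encoder outputs; it is omitted here for simplicity.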

# Example usage
if __name__ == "__main__":
    # Model hyperparameters
    num_layers = 2
    d_model = 64
    num_heads = 4
    d_ff = 128
    dropout = 0.1
    # Build the decoder
    decoder = TransformerDecoder(num_layers, d_model, num_heads, d_ff, dropout)
    # Random input of shape (batch_size, seq_len, d_model)
    batch_size = 2
    seq_len = 10
    x = torch.rand(batch_size, seq_len, d_model)
    # Causal (lower-triangular) mask of shape (batch_size, seq_len, seq_len)
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).expand(batch_size, -1, -1)
    # Forward pass
    output = decoder(x, mask)
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
```