# Transformer demo (Python, PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each projected by their own ``nn.Linear``, split
    into ``num_heads`` heads of width ``d_model // num_heads``, attended per
    head with scaled dot-product attention, re-concatenated, and passed
    through a final output projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Real exceptions rather than ``assert``, which ``python -O`` strips.
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Tensor expected as (batch, q_len, d_model).
            key: Tensor expected as (batch, k_len, d_model).
            value: Tensor expected as (batch, k_len, d_model); its length must
                match ``key``.
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, k_len). Positions where
                ``mask == 0`` receive no attention weight.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.size(0)

        # Linear projections, each split into heads.
        q = self._split_heads(self.wq(query), batch_size)
        k = self._split_heads(self.wk(key), batch_size)
        v = self._split_heads(self.wv(value), batch_size)

        # Scaled dot-product scores: (batch, num_heads, q_len, k_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf.
            # A row whose keys are all masked then softmaxes to uniform
            # weights rather than NaN. In partially masked rows the masked
            # entries still get weight exactly 0, because exp underflows.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Concatenate heads back to (batch, q_len, d_model) and project.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.fc(attn_output)
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Feature size of the input and of the output.
        d_ff: Width of the hidden (inner) layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expand to the inner width, then project back down to d_model.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the two linear layers with a ReLU between them."""
        hidden = self.fc1(x)
        hidden = F.relu(hidden)
        return self.fc2(hidden)
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention is followed by a position-wise feed-forward network. Each
    sub-layer's output goes through dropout, is added back to its input
    (residual connection), and is then layer-normalized.

    Args:
        d_model: Feature size of the layer's inputs and outputs.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run self-attention then feed-forward, each with add & norm.

        ``mask`` is forwarded unchanged to the self-attention sub-layer.
        """
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Runs three sub-layers in order: self-attention over the decoder input,
    cross-attention from the decoder input to ``enc_output``, and a
    position-wise feed-forward network. Each sub-layer's output is passed
    through dropout, added to that sub-layer's input, and layer-normalized.

    Args:
        d_model: Feature size of the layer's inputs and outputs.
        num_heads: Number of heads in each attention sub-layer.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_out, norm):
        """Return ``norm(residual + dropout(sublayer_out))``."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Apply self-attention, cross-attention and feed-forward sub-layers.

        Args:
            x: Decoder hidden states.
            enc_output: Encoder output, used as keys and values in
                cross-attention.
            src_mask: Optional mask for cross-attention over encoder positions.
            tgt_mask: Optional mask for decoder self-attention. Supply a
                causal mask to stop positions attending to later positions.
        """
        # Self-attention over the decoder sequence.
        x = self._add_and_norm(x, self.self_attn(x, x, x, tgt_mask), self.norm1)
        # Cross-attention: decoder positions attend to the encoder output.
        x = self._add_and_norm(
            x, self.cross_attn(x, enc_output, enc_output, src_mask), self.norm2
        )
        # Position-wise feed-forward network.
        return self._add_and_norm(x, self.feed_forward(x), self.norm3)
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Source and target token ids are embedded and summed with a learned
    positional embedding. The positional table is a parameter initialized to
    zeros and shared by encoder and decoder. The sums go through dropout, then
    through ``num_layers`` encoder layers and ``num_layers`` decoder layers.
    A final linear layer maps decoder states to logits.

    Args:
        src_vocab_size: Number of source tokens.
        tgt_vocab_size: Number of target tokens (size of the output logits).
        d_model: Embedding / hidden size.
        num_heads: Attention heads per attention sub-layer.
        num_layers: Number of encoder layers and of decoder layers.
        d_ff: Hidden size of each feed-forward network.
        max_seq_length: Longest source or target sequence accepted.
        dropout: Dropout probability.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout=0.1):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Learned positional embedding, shape (1, max_seq_length, d_model).
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _embed(self, tokens, embedding):
        """Embed (batch, seq) token ids, add positional embeddings, apply dropout.

        Raises:
            ValueError: If ``seq`` exceeds the positional table length.
        """
        seq_length = tokens.size(1)
        max_length = self.positional_encoding.size(1)
        # Without this check the positional slice comes out shorter than the
        # embeddings and the addition fails with an opaque broadcast error.
        if seq_length > max_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length {max_length}"
            )
        return self.dropout(embedding(tokens) + self.positional_encoding[:, :seq_length, :])

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src`` and decode ``tgt`` into logits.

        Args:
            src: Source token ids, (batch, src_len).
            tgt: Target (decoder input) token ids, (batch, tgt_len).
            src_mask: Optional mask passed to every encoder layer and to every
                decoder layer's cross-attention.
            tgt_mask: Optional mask for decoder self-attention. ``None``
                applies no masking, so for autoregressive training pass a
                causal (lower-triangular) mask.

        Returns:
            Logits of shape (batch, tgt_len, tgt_vocab_size).

        Raises:
            ValueError: If ``src_len`` or ``tgt_len`` exceeds ``max_seq_length``.
        """
        # Embedding + positional encoding
        src = self._embed(src, self.encoder_embedding)
        tgt = self._embed(tgt, self.decoder_embedding)

        # Encoder
        enc_output = src
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)

        # Decoder
        dec_output = tgt
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask)

        # Output projection to vocabulary logits
        return self.fc_out(dec_output)
# Training script
def train_transformer():
    """Train a small Transformer on one random batch as a smoke-test demo.

    Builds the model with fixed hyperparameters and generates random token
    ids. Ids start at 1, so 0 stays free as the padding index ignored by the
    loss. The decoder uses teacher forcing with a causal mask. The function
    runs ``num_epochs`` full-batch optimization steps and prints the loss
    after each one.
    """
    # Hyperparameters
    src_vocab_size = 1000
    tgt_vocab_size = 1000
    d_model = 64
    num_heads = 4
    num_layers = 2
    d_ff = 128
    max_seq_length = 20
    dropout = 0.1
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.0001

    # Create the model
    model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore the padding index

    # Generate random data. Teacher forcing: the decoder input is the target
    # without its last token, and the labels are the target shifted left by one.
    src = torch.randint(1, src_vocab_size, (batch_size, max_seq_length))
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, max_seq_length))
    tgt_input = tgt[:, :-1]
    tgt_y = tgt[:, 1:]

    # Causal mask: position i may attend only to positions <= i, so the decoder
    # cannot see the tokens it must predict. Shape (1, 1, L, L) broadcasts over
    # batch and heads; zeros mark blocked positions.
    tgt_len = tgt_input.size(1)
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output = model(src, tgt_input, tgt_mask=tgt_mask)
        # reshape, not view: tgt_y is a non-contiguous slice.
        loss = criterion(output.reshape(-1, tgt_vocab_size), tgt_y.reshape(-1))
        loss.backward()
        optimizer.step()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    print("训练完成!")
if __name__ == "__main__":
    # Run the demo training loop only when this file is executed directly.
    train_transformer()