Parallel Transformer Code
The following is a Parallel Transformer implemented in PyTorch; it includes the model definition, data preprocessing, training, and evaluation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
import argparse
import os
import time
import math
from tqdm import tqdm

class ParallelTransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        seq_len = x.size(1)
        # Causal mask (-inf above the diagonal) so each position attends only to
        # earlier positions; the targets are the inputs shifted by one token.
        attn_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.transformer_layers:
            x = layer(x, attn_mask=attn_mask)
        x = self.fc(x)
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, embed_dim) for batch-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, dropout):
        super().__init__()
        # batch_first=True so attention consumes (batch, seq_len, embed_dim),
        # matching the batch-first data pipeline below.
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Pre-norm self-attention sub-layer with residual connection
        x_norm = self.norm1(x)
        attn_out, _ = self.self_attn(x_norm, x_norm, x_norm, attn_mask=attn_mask)
        x = x + self.dropout1(attn_out)
        # Pre-norm feed-forward sub-layer with residual connection
        x_norm = self.norm2(x)
        fc_out = self.fc2(F.relu(self.fc1(x_norm)))
        x = x + self.dropout2(fc_out)
        return x

class TextDataset(Dataset):
    def __init__(self, data_file, vocab_file):
        self.data = []
        self.vocab = {}
        self.max_len = 0
        # Vocabulary file: one token per line; the line number becomes the token id
        with open(vocab_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                self.vocab[line.strip()] = idx
        # Data file: one whitespace-tokenized sequence per line
        with open(data_file, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = line.strip().split()
                if self.max_len < len(tokens):
                    self.max_len = len(tokens)
                self.data.append([self.vocab[token] for token in tokens])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def collate_fn(self, batch):
        # Pad with 0, the index the loss ignores via ignore_index=0
        batch = pad_sequence([torch.tensor(data) for data in batch], batch_first=True)
        return batch

def train(args, model, dataloader, criterion, optimizer, epoch):
    model.train()
    epoch_loss = 0
    for batch in tqdm(dataloader, desc=f'Train epoch {epoch}'):
        optimizer.zero_grad()
        # Next-token prediction: targets are the inputs shifted by one position
        inputs, targets = batch[:, :-1], batch[:, 1:]
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        outputs = model(inputs)
        loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

def evaluate(args, model, dataloader, criterion, epoch):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f'Eval epoch {epoch}'):
            inputs, targets = batch[:, :-1], batch[:, 1:]
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = model(inputs)
            loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1))
            epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

def main(args):
    torch.manual_seed(args.seed)
    # Initialize distributed training (one process per GPU)
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # Load and preprocess data
    train_dataset = TextDataset(args.train_file, args.vocab_file)
    eval_dataset = TextDataset(args.eval_file, args.vocab_file)
    train_sampler = DistributedSampler(train_dataset)
    eval_sampler = DistributedSampler(eval_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_dataset.collate_fn, num_workers=args.num_workers, sampler=train_sampler)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, collate_fn=eval_dataset.collate_fn, num_workers=args.num_workers, sampler=eval_sampler)
    # Initialize model and optimizer
    model = ParallelTransformerModel(len(train_dataset.vocab), args.embed_dim, args.num_heads, args.num_layers, args.hidden_dim, args.dropout)
    model = model.to(args.device)  # parameters must be on this process's GPU before wrapping in DDP
    model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is the padding id produced by pad_sequence
    optimizer = Adam(model.parameters(), lr=args.learning_rate)
    # Train and evaluate
    for epoch in range(1, args.num_epochs + 1):
        train_sampler.set_epoch(epoch)  # reshuffle the per-process shards each epoch
        train_loss = train(args, model, train_dataloader, criterion, optimizer, epoch)
        eval_loss = evaluate(args, model, eval_dataloader, criterion, epoch)
        # Average the losses across all processes
        train_loss = torch.tensor(train_loss).to(args.device)
        eval_loss = torch.tensor(eval_loss).to(args.device)
        torch.distributed.reduce(train_loss, dst=0)
        torch.distributed.reduce(eval_loss, dst=0)
        if torch.distributed.get_rank() == 0:  # only the destination rank holds the reduced sums
            train_loss /= torch.distributed.get_world_size()
            eval_loss /= torch.distributed.get_world_size()
            print(f'Train epoch {epoch}, loss: {train_loss.item():.4f}')
            print(f'Eval epoch {epoch}, loss: {eval_loss.item():.4f}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', type=str, default='train.txt')
    parser.add_argument('--eval_file', type=str, default='eval.txt')
    parser.add_argument('--vocab_file', type=str, default='vocab.txt')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--embed_dim', type=int, default=256)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--num_layers', type=int, default=6)
    parser.add_argument('--hidden_dim', type=int, default=1024)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    args.device = torch.device('cuda', args.local_rank)
    main(args)
```
Note: the code above is only an example; the actual implementation will vary with your environment, data, and other factors.
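Because the script reads `--local_rank` from the command line and initializes the process group with `init_method='env://'`, it is meant to be started through PyTorch's distributed launcher, e.g. `python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py` on older PyTorch releases (the filename `train.py` is an assumption); newer releases and `torchrun` supply the local rank through the `LOCAL_RANK` environment variable instead, so the argument handling may need a small adjustment. Independent of the distributed setup, the model class itself can be sanity-checked on a dummy batch. The snippet below is only an illustrative sketch with made-up sizes, assuming the classes above are in scope:
```python
import torch

# Hypothetical sanity check: small vocabulary, short sequences, CPU only
model = ParallelTransformerModel(vocab_size=1000, embed_dim=256, num_heads=8,
                                 num_layers=2, hidden_dim=1024, dropout=0.1)
dummy_batch = torch.randint(1, 1000, (4, 32))  # (batch, seq_len) of token ids
logits = model(dummy_batch)
print(logits.shape)  # expected: torch.Size([4, 32, 1000])
```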