If you only write functions and never add an `if __name__ == '__main__'` block, how do you run the code?
You're right, I left out the `if __name__ == '__main__'` guard. Below is a complete version that includes it, which you can run directly.
```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, dropout_rate):
        super(TransformerBlock, self).__init__()
        # Multi-head self-attention; batch_first=True so inputs are (batch, seq, dim)
        self.self_attn = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, input_dim),
            nn.Dropout(dropout_rate)
        )
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, x, mask=None):
        # Multi-head self-attention layer
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        # Layer normalization and residual connection
        x = self.layer_norm1(x + attn_output)
        # Feed-forward layer
        ffn_output = self.feed_forward(x)
        # Layer normalization and residual connection
        x = self.layer_norm2(x + ffn_output)
        return x


class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, dropout_rate):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Learnable positional encoding, fixed to a maximum of 100 positions
        self.pos_encoding = nn.Parameter(torch.zeros(1, 100, embedding_dim))
        self.dropout = nn.Dropout(dropout_rate)
        self.num_layers = num_layers
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embedding_dim, hidden_dim, num_heads, dropout_rate)
            for _ in range(num_layers)
        ])
        # Project hidden states back to vocabulary logits
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, seq_len, embedding_dim)
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
        x = self.dropout(x)
        mask = self._generate_mask(x)
        for i in range(self.num_layers):
            x = self.transformer_blocks[i](x, mask)
        x = self.fc(x)
        return x

    def _generate_mask(self, x):
        # Causal mask: True above the diagonal blocks attention to future positions
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        return mask.to(x.device)


if __name__ == '__main__':
    # Example input sequence of token ids, shape (batch=1, seq_len=4)
    input_seq = torch.tensor([[1, 2, 3, 4]])

    # Model hyperparameters; embedding_dim must be divisible by num_heads
    vocab_size = 10000
    embedding_dim = 128
    hidden_dim = 256
    num_layers = 4
    num_heads = 8
    dropout_rate = 0.1

    # Create the model
    model = TransformerModel(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, dropout_rate)

    # Forward pass: output has shape (batch, seq_len, vocab_size)
    output = model(input_seq)
    print(output)
```
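As a quick sanity check on `_generate_mask`, the standalone sketch below mirrors the method's construction and prints the boolean causal mask for a length-4 sequence; `True` entries mark the future positions that attention is blocked from seeing.

```python
import torch

seq_len = 4
# Same construction as _generate_mask: True above the diagonal means "masked"
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```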
The full listing above is a complete implementation of a Transformer model together with a small runnable example; executing the file prints the model's output.
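To answer the original question more directly: Python sets `__name__` to `'__main__'` only when a file is executed as a script, so module-level code outside the guard runs on both direct execution and `import`, while code inside the guard runs only on direct execution. A minimal sketch of the mechanism, assuming a file named `demo.py` (the file name is illustrative):

```python
# demo.py -- illustrative file name, not from the original thread
def greet():
    print("hello from greet()")

# Module-level code runs both via `python demo.py` and via `import demo`
print(f"__name__ is {__name__!r}")

if __name__ == '__main__':
    # This branch runs only via `python demo.py`, never on import
    greet()
```

Even without any guard, you can still run a single function from the command line, e.g. `python -c "import demo; demo.greet()"`.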