Self-attention code implementation
Sure, here's an example code implementation for self-attention in Python:
```python
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, hidden_size):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (batch_size, seq_length, hidden_size)
        # Project inputs to queries, keys, and values
        query = self.query(x)    # (batch, seq, hidden)
        key = self.key(x)        # (batch, seq, hidden)
        value = self.value(x)    # (batch, seq, hidden)

        # Scaled dot-product scores: every position attends to every position
        scores = torch.matmul(query, key.transpose(-2, -1))  # (batch, seq, seq)
        scores = scores / math.sqrt(self.hidden_size)

        # Softmax over the key dimension gives the attention weights
        attention_weights = self.softmax(scores)              # (batch, seq, seq)

        # Weighted sum of the values
        weighted_values = torch.matmul(attention_weights, value)  # (batch, seq, hidden)
        return weighted_values


# Example usage
hidden_size = 64
seq_length = 10
batch_size = 2

input_data = torch.randn(batch_size, seq_length, hidden_size)
self_attention = SelfAttention(hidden_size)
output = self_attention(input_data)
print(output.size())  # torch.Size([2, 10, 64])
```
This code implements a simple single-head self-attention mechanism in PyTorch. The `SelfAttention` class takes the `hidden_size` and defines three linear layers for the query, key, and value projections. In the `forward` method, the input of shape `(batch, seq, hidden)` is projected to queries, keys, and values; the attention scores are computed as the scaled dot product `QK^T / sqrt(hidden_size)`, producing a `(batch, seq, seq)` matrix in which each position scores every other position. A softmax over the last (key) dimension turns these scores into attention weights, and multiplying the weights with the values yields the output, which has the same shape as the input. The example usage at the end confirms the output size `torch.Size([2, 10, 64])`.
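As a quick sanity check, the same computation can be reproduced with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later) by reusing the projection layers of the module. This is a minimal sketch rather than part of the original answer, and it assumes the `self_attention` instance and `input_data` / `output` tensors defined above:

```python
import torch.nn.functional as F

# Reuse the projections from the SelfAttention instance defined above
with torch.no_grad():
    q = self_attention.query(input_data)
    k = self_attention.key(input_data)
    v = self_attention.value(input_data)

    # Built-in scaled dot-product attention (requires PyTorch >= 2.0)
    reference = F.scaled_dot_product_attention(q, k, v)

    # Should match the hand-written implementation up to floating-point tolerance
    print(torch.allclose(output, reference, atol=1e-5))  # Expected: True
```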