nn.Embedding( ) 用法是什么
时间: 2023-08-11 14:10:40 浏览: 93
`nn.Embedding()` 是 PyTorch 中的一个类,用于创建一个嵌入层(Embedding Layer),将离散的整数 ID 映射为连续的实数向量。
`nn.Embedding(num_embeddings, embedding_dim, padding_idx=None)` 的参数如下:
- `num_embeddings`:词典大小,即词汇表中单词的个数;
- `embedding_dim`:嵌入向量的维度;
- `padding_idx`:指定一个 padding 的索引,将这个索引对应的嵌入向量设置为全 0 向量。
`nn.Embedding()` 的输入是一个 LongTensor,形状为 $(batch\_size, seq\_len)$,其中每个元素是一个整数 ID,表示词汇表中的一个单词。输出是一个 FloatTensor,形状为 $(batch\_size, seq\_len, embedding\_dim)$,其中每个元素是一个实数,表示对应单词的嵌入向量。
使用示例:
```python
import torch.nn as nn
# 创建一个词汇表,包含 1000 个单词
vocab_size = 1000
# 嵌入向量的维度为 128
embedding_dim = 128
# 创建一个嵌入层
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
# 输入为一个 LongTensor,形状为 (batch_size, seq_len)
input_ids = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
# 将输入传递到嵌入层中,得到输出
# 输出的形状为 (batch_size, seq_len, embedding_dim)
output = embedding(input_ids)
```
阅读全文