nn.Embedding()
时间: 2023-12-10 09:46:55 浏览: 163
nn.Embedding() 是 PyTorch 中的一个模块,用于实现词嵌入(Word Embedding)操作。词嵌入是将离散的词语映射到连续的向量空间中的一种方式,它可以捕捉到词语之间的语义关系。
nn.Embedding() 的调用方式如下:
```python
embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=None)
```
其中,`num_embeddings` 是词表的大小,也是词嵌入矩阵的行数;`embedding_dim` 是每个词嵌入向量的维度;`padding_idx` 是可选参数,用于指定需要进行填充的词的索引。
以下是一个示例:
```python
import torch
import torch.nn as nn
# 定义一个词嵌入层
embedding = nn.Embedding(10000, 300)
# 输入一个大小为 (3, 4) 的张量
input = torch.LongTensor([[1,2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 将输入张量中的每个词索引映射为对应的词嵌入向量
output = embedding(input)
print(output.size()) # 输出: torch.Size([3, 4, 300])
```
在这个示例中,我们定义了一个词嵌入层,词表大小为10000,每个词嵌入向量的维度为300。然后,我们输入一个大小为 (3, 4) 的张量,其中包含了3个句子,每个句子由4个词组成。通过调用 `embedding(input)`,将输入张量中的每个词索引映射为对应的词嵌入向量,输出的张量大小为 `(3, 4, 300)`,即每个词都被映射为一个300维的词嵌入向量。
阅读全文