nn.Embedding 的输入跟输出
时间: 2024-01-30 11:11:27 浏览: 135
nn.Embedding是PyTorch中的一个模块,用于将离散的整数映射到连续的向量空间中。它的输入是一个整数张量,输出是对应的向量表示。
具体来说,nn.Embedding的输入是一个大小为(batch_size, seq_length)的整数张量,其中batch_size表示批次大小,seq_length表示序列长度。每个整数值代表一个离散的类别或标签。
nn.Embedding的输出是一个大小为(batch_size, seq_length, embedding_dim)的张量,其中embedding_dim表示嵌入向量的维度。输出张量的每个元素是对应输入整数值的嵌入向量表示。
下面是一个示例代码,演示了nn.Embedding的输入和输出:
```python
import torch
import torch.nn as nn
# 定义一个Embedding层
embedding = nn.Embedding(10, 3) # 输入大小为10,嵌入维度为3
# 定义输入张量
input_tensor = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
# 使用Embedding层进行嵌入
output_tensor = embedding(input_tensor)
print("输入张量大小:", input_tensor.size()) # 输出:torch.Size([2, 3])
print("输出张量大小:", output_tensor.size()) # 输出:torch.Size([2, 3, 3])
```
在上面的示例中,输入张量的大小是(2, 3),表示有两个样本,每个样本有三个整数值。输出张量的大小是(2, 3, 3),表示有两个样本,每个样本有三个嵌入向量,每个嵌入向量的维度是3。
阅读全文