PyTorch embedding demo code
Date: 2023-10-16 17:07:30
PyTorch test code
Here is a simple PyTorch embedding demo code:
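```python
import torch
import torch.nn as nn

# Input: a batch of 2 sequences, each holding 3 integer indices (dtype torch.long)
input_data = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)

# Embedding layer: a lookup table of 10 vectors, each of size 4
# (valid indices are therefore 0..9)
embedding_layer = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Look up the embedding vector for every index in the input
embedded_data = embedding_layer(input_data)

# Print the embeddings and their shape, torch.Size([2, 3, 4])
print(embedded_data)
print(embedded_data.shape)
```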
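In this code, we first define the input data as an integer (torch.long) tensor of shape (2, 3), where each entry is an index into the embedding table. We then create an embedding layer with nn.Embedding, specifying the number of embeddings (num_embeddings, effectively the vocabulary size, so valid indices run from 0 to 9) and the size of each embedding vector (embedding_dim). Calling embedding_layer(input_data) replaces each index with the corresponding row of the layer's weight matrix, so embedded_data has shape (2, 3, 4): one 4-dimensional vector for each of the 2 × 3 input indices. Finally, we print the output. Because nn.Embedding initializes its weights randomly (from a standard normal distribution by default), the printed values differ from run to run, and any index outside 0 to 9 would raise an IndexError.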
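To make the lookup concrete: each output vector is simply a row of the layer's weight matrix, embedding_layer.weight, which has shape (num_embeddings, embedding_dim) = (10, 4). Indexing that matrix with the same indices, or multiplying a one-hot encoding of the indices by it, should therefore produce the same result as calling the layer. The sketch below checks this; the fixed seed and the comparison variables (by_indexing, one_hot, by_matmul) are illustrative additions, not part of the original demo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the random weights are reproducible

input_data = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)
embedding_layer = nn.Embedding(num_embeddings=10, embedding_dim=4)
embedded_data = embedding_layer(input_data)  # shape (2, 3, 4)

# 1) The lookup is plain row indexing into the (10, 4) weight matrix
by_indexing = embedding_layer.weight[input_data]  # shape (2, 3, 4)

# 2) Equivalently, one-hot encode each index and multiply by the weights
one_hot = F.one_hot(input_data, num_classes=10).float()  # shape (2, 3, 10)
by_matmul = one_hot @ embedding_layer.weight              # shape (2, 3, 4)

print(embedded_data.shape)                       # torch.Size([2, 3, 4])
print(torch.equal(embedded_data, by_indexing))   # True
print(torch.allclose(embedded_data, by_matmul))  # True
```

The indexing form is what nn.Embedding effectively does and is much cheaper than the one-hot product, which is why an embedding layer is preferred over a linear layer applied to one-hot vectors when the vocabulary is large.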