python中unsqueeze函数用法
时间: 2024-06-12 10:09:02 浏览: 158
详解pytorch中squeeze()和unsqueeze()函数介绍
unsqueeze函数是PyTorch库中的一个函数,用于在给定索引处插入新的维度。它的语法是torch.unsqueeze(tensor, dim),其中tensor是输入的张量,dim是要插入新维度的索引。例如,当dim=0时,unsqueeze函数将在张量的最外层插入一个新的维度。unsqueeze函数返回一个新的张量。
以下是一个示例:
import torch
x = torch.ones(4)
print(x.size()) # 输出torch.Size()
y = torch.unsqueeze(x, 0)
print(y.size()) # 输出torch.Size([1, 4])
在这个示例中,初始张量x的大小是,使用unsqueeze函数在索引0处插入一个新的维度,得到的新张量y的大小是[1, 4]。
阅读全文