Y_train_binary = torch.FloatTensor( label_binarize(Y_train, classes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) )怎么改为四维
时间: 2023-10-22 11:09:50 浏览: 111
PyTorch中Tensor的维度变换实现
张量?
你可以使用 `torch.unsqueeze()` 方法将该张量的维度扩充为四维。假设原来的张量形状为 `(batch_size, num_classes)`,你可以按以下方式将其转换为四维张量:
```
Y_train_binary = torch.unsqueeze(Y_train_binary, dim=0) # 在第0维扩充一个维度
Y_train_binary = torch.unsqueeze(Y_train_binary, dim=2) # 在第2维扩充一个维度
Y_train_binary = torch.unsqueeze(Y_train_binary, dim=3) # 在第3维扩充一个维度
```
这样就得到了形状为 `(1, batch_size, 1, 1, num_classes)` 的四维张量。
阅读全文