import torch def AddDataToTorch(oriData, newData, addDim): if oriData is None: return newData else: if (len(oriData.shape) > addDim) and (len(newData.shape) > addDim): return torch.cat((oriData, newData), dim = addDim) else: return torch.cat((oriData, newData), dim = 1) mergeData = None tmpData = torch.randn(3) print(tmpData.shape) for i in range(1, 10): print('indx = ') print(i) mergeData = AddDataToTorch(mergeData, tmpData, 0) print(mergeData.shape) 这段代码如何修改能输出9行3列的torch
时间: 2023-06-26 13:08:36 浏览: 133
可以将代码修改为以下形式:
```
import torch
def AddDataToTorch(oriData, newData, addDim):
if oriData is None:
return newData
else:
if (len(oriData.shape) > addDim) and (len(newData.shape) > addDim):
return torch.cat((oriData, newData), dim = addDim)
else:
return torch.cat((oriData, newData), dim = 1)
mergeData = None
tmpData = torch.randn(1, 3) # 修改这里,生成1行3列的随机数张量
print(tmpData.shape)
for i in range(1, 10):
print('indx = ')
print(i)
mergeData = AddDataToTorch(mergeData, tmpData, 0)
print(mergeData.shape) # 输出9行3列的张量
```
将 `tmpData` 的随机数张量修改为 1 行 3 列的张量,然后在循环中使用 `AddDataToTorch` 方法将其添加到 `mergeData` 中,最后输出结果张量 `mergeData` 的形状即可。
阅读全文