torch 简单使用
时间: 2023-10-21 08:30:03 浏览: 40
Torch 是一个基于 Lua 语言的机器学习框架,它提供了一个丰富的工具箱,使研究人员和开发人员能够轻松地构建和训练复杂的神经网络。
以下是 Torch 的简单使用示例:
1. 安装 Torch
可以在官方网站上下载并安装 Torch,也可以使用以下命令在 Linux 和 Mac 上安装:
```
curl -s https://raw.githubusercontent.com/torch/ezinstall/master/install-deps | bash
git clone https://github.com/torch/distro.git ~/torch --recursive
cd ~/torch; bash install.sh
source ~/.bashrc
```
2. 使用 Torch 进行线性回归
下面是一个使用 Torch 进行线性回归的简单示例:
```lua
-- 导入 Torch 库
require 'torch'
-- 随机生成训练数据
x = torch.randn(100, 1)
y = 3 * x + torch.randn(100, 1) * 0.1
-- 定义模型
model = nn.Linear(1, 1)
-- 定义损失函数
criterion = nn.MSECriterion()
-- 定义优化器
optimizer = optim.SGD(model:getParameters(), 0.1)
-- 训练模型
for i = 1, 1000 do
-- 前向传播
y_hat = model:forward(x)
-- 计算损失
loss = criterion:forward(y_hat, y)
-- 反向传播
gradient = criterion:backward(y_hat, y)
model:backward(x, gradient)
-- 更新参数
optimizer:update()
-- 打印损失
if i % 100 == 0 then
print('loss:', loss)
end
end
-- 预测
x_test = torch.Tensor{0.5}
y_test = model:forward(x_test)
print('y_test:', y_test[1])
```
以上代码生成了一个大小为 100 的随机训练数据集,使用线性模型进行拟合,并输出了预测结果。