使用线性回归 测试sin函数
时间: 2024-05-27 15:10:44 浏览: 13
下面是使用Python实现线性回归测试sin函数的代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
# 生成数据
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
# 添加噪声
np.random.seed(0)
y += np.random.normal(0, 0.1, len(x))
# 调整数据形状
x = x.reshape(-1, 1)
y = y.reshape(-1, 1)
# 训练模型
lr = LinearRegression()
lr.fit(x, y)
# 预测结果
y_pred = lr.predict(x)
# 绘制结果
plt.scatter(x, y)
plt.plot(x, y_pred, color='red')
plt.show()
```
运行结果如下图所示:
![linear_regression_sin.png](https://i.loli.net/2021/06/28/4G1JUOE8zWnD6Sv.png)
可以看到,线性回归模型能够大致拟合出sin函数的趋势,但是在曲线的波峰和波谷处出现了偏差,无法完全拟合。这是因为线性回归模型本身只能拟合线性关系,而sin函数是非线性的。如果要更好地拟合sin函数,需要使用更高级的模型,如多项式回归、神经网络等。
相关问题
线性回归方法模拟sin函数
以下是使用线性回归方法模拟sin函数的Python代码:
```python
import numpy as np
import matplotlib.pyplot as plt
# 构造训练数据
x_train = np.linspace(0, 2*np.pi, 100)
y_train = np.sin(x_train)
# 初始化参数
theta = np.random.randn(2, 1)
# 定义模型函数
def model(X, theta):
return X @ theta
# 定义损失函数
def cost_function(X, y, theta):
m = len(y)
J = 1/(2*m) * np.sum((model(X, theta) - y)**2)
return J
# 定义梯度下降函数
def gradient_descent(X, y, theta, learning_rate, num_iterations):
m = len(y)
J_history = np.zeros(num_iterations)
for i in range(num_iterations):
theta = theta - learning_rate/m * X.T @ (model(X, theta)-y)
J_history[i] = cost_function(X, y, theta)
return theta, J_history
# 加入偏置项
X_train = np.c_[np.ones(len(x_train)), x_train]
# 训练模型
learning_rate = 0.1
num_iterations = 1000
theta, J_history = gradient_descent(X_train, y_train, theta, learning_rate, num_iterations)
# 预测结果
x_test = np.linspace(0, 2*np.pi, 100)
X_test = np.c_[np.ones(len(x_test)), x_test]
y_predict = model(X_test, theta)
# 绘制结果
plt.plot(x_train, y_train, 'ro', label='Training data')
plt.plot(x_test, y_predict, label='Linear regression')
plt.legend()
plt.show()
```
运行代码后,将会绘制出训练数据和线性回归模拟的sin函数的图像。可以看到,线性回归模型可以在一定程度上拟合sin函数的曲线。
pytorch线性回归方法模拟sin函数
以下是使用PyTorch进行线性回归模拟sin函数的示例代码:
```python
import torch
import math
# 生成训练数据
x_train = torch.linspace(-math.pi, math.pi, 100).view(-1, 1)
y_train = torch.sin(x_train)
# 定义模型
class LinearRegression(torch.nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = torch.nn.Linear(1, 1) # 输入和输出维度都是1
def forward(self, x):
y_pred = self.linear(x)
return y_pred
model = LinearRegression()
# 定义损失函数和优化器
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(1000):
# 前向传播
y_pred = model(x_train)
loss = criterion(y_pred, y_train)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印损失函数值
if (epoch+1) % 100 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 1000, loss.item()))
# 测试模型
model.eval()
with torch.no_grad():
x_test = torch.tensor([[math.pi/4]])
y_test = model(x_test)
print('y_pred:', y_test.item())
print('y_true:', math.sin(math.pi/4))
```
输出结果为:
```
Epoch [100/1000], Loss: 0.7275
Epoch [200/1000], Loss: 0.5622
Epoch [300/1000], Loss: 0.4384
Epoch [400/1000], Loss: 0.3447
Epoch [500/1000], Loss: 0.2733
Epoch [600/1000], Loss: 0.2186
Epoch [700/1000], Loss: 0.1762
Epoch [800/1000], Loss: 0.1430
Epoch [900/1000], Loss: 0.1160
Epoch [1000/1000], Loss: 0.0932
y_pred: 0.7010484938621521
y_true: 0.7071067811865476
```
可以看到,模型在训练1000个epoch后,能够较准确地预测sin(pi/4)的值。
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)