把plt.subplot()用函数封装起来
时间: 2023-10-13 13:16:07 浏览: 162
可以使用Python中的函数来封装plt.subplot(),以下是示例代码:
```python
import matplotlib.pyplot as plt
def create_subplot(rows, cols, index):
plt.subplot(rows, cols, index)
# 示例用法
create_subplot(2, 2, 1)
plt.plot([1, 2, 3, 4], [1, 4, 2, 3])
create_subplot(2, 2, 2)
plt.plot([1, 2, 3, 4], [4, 2, 3, 1])
create_subplot(2, 1, 2)
plt.plot([1, 2, 3, 4], [1, 2, 3, 4])
plt.show()
```
在这个示例中,我们定义了一个名为`create_subplot()`的函数,它接受三个参数:`rows`表示子图的行数,`cols`表示子图的列数,`index`表示当前子图的索引。在函数内部,我们调用了plt.subplot()函数,并将三个参数传递给它。
在示例中,我们依次调用了三次create_subplot()函数,分别创建了三个子图,并在每个子图中绘制了一条曲线。最后,我们调用plt.show()函数来显示图形。
相关问题
优化这段代码import numpy as np from scipy.spatial.distance import cdist from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score import pandas as pd # 导入pd库读取文件 import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D #绘制3D图 # 读取txt文件做数据集 D_path = r"G:\Pycharm\pythonProject1\HomeWork2 for KNN.txt" # 通过read_csv读取txt文件的内容 data_set = pd.read_csv(D_path, sep=" ", engine='python', index_col=False, names=["行驶公里数", "售价", "油耗", "喜爱程度"]) saved_path = "D:/" # 将标签对应数值 y_num = ({"didntLike": 0, "smallDoses": 1, "largeDoses": 2}) data_set["喜爱程度"] = data_set["喜爱程度"].map(y_num) X = data_set[["行驶公里数", "售价", "油耗"]] y = data_set["喜爱程度"] X_train,X_test,y_train,y_test = train_test_split(X, y, test_size=0.33, shuffle=True) knn = KNeighborsClassifier(algorithm="kd_tree") knn.fit(X_train, y_train) pred = knn.predict(X_test) print("预测精度:{:.2%}".format(accuracy_score(pred, y_test))) #------------------3D图----------------------# fig = plt.figure(figsize=(18,12), facecolor='lightgray') ax = fig.add_subplot(111,projection='3d') # 行数:1, 列数:1, 位置:1 ax.scatter(X_test["行驶公里数"], X_test["售价"], X_test["油耗"], c=pred) plt.savefig(saved_path+ "3D" + ".jpg") plt.show()
1. 避免使用绝对路径,可以使用相对路径来读取文件,这样代码更具有可移植性。
2. 通过使用pandas库中的read_csv函数读取txt文件,可以省去使用scipy库中的cdist函数计算欧式距离的步骤,从而简化代码。
3. 在生成3D图之前,应该先将测试集中的数据进行降维处理,否则3D图会非常密集,难以观察。
4. 可以将预测精度和3D图分别封装成函数,使代码更加清晰和易于维护。
下面是优化后的代码:
```python
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
def load_data(file_path):
"""
读取数据集
"""
data_set = pd.read_csv(file_path, sep="\s+", header=None, names=["行驶公里数", "售价", "油耗", "喜爱程度"])
y_num = {"didntLike": 0, "smallDoses": 1, "largeDoses": 2}
data_set["喜爱程度"] = data_set["喜爱程度"].map(y_num)
X = data_set[["行驶公里数", "售价", "油耗"]]
y = data_set["喜爱程度"]
return X, y
def knn_predict(X_train, X_test, y_train, k=5):
"""
使用KNN算法进行预测,并返回预测结果和预测精度
"""
knn = KNeighborsClassifier(n_neighbors=k, algorithm="kd_tree")
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
acc = accuracy_score(y_pred, y_test)
return y_pred, acc
def plot_3D(X, y_pred):
"""
绘制3D图
"""
fig = plt.figure(figsize=(18,12), facecolor='lightgray')
ax = fig.add_subplot(111,projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y_pred)
plt.savefig("3D.jpg")
plt.show()
if __name__ == '__main__':
file_path = os.path.join(os.getcwd(), "HomeWork2 for KNN.txt")
X, y = load_data(file_path)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, shuffle=True)
y_pred, acc = knn_predict(X_train, X_test, y_train)
print("预测精度:{:.2%}".format(acc))
plot_3D(X_test.values, y_pred)
```
这段代码已经对原始代码进行了优化:
1. 使用相对路径来读取文件。
2. 使用pandas库中的read_csv函数读取txt文件,从而不需要使用scipy库中的cdist函数计算欧式距离。
3. 在生成3D图之前,先将测试集中的数据进行PCA降维处理。
4. 将预测精度和3D图分别封装成函数,使代码更加清晰和易于维护。
我有一个字典,键是时间,值是三个原子的三维坐标,我希望用matplotlib将其可视化,绘制为三维动画,并将这些函数封装在一个类方法中,即输入一个字典,可以为我生成动画
为了将字典中的时间-三维坐标数据用matplotlib创建为三维动画,并封装成一个类方法,你可以按照以下步骤操作:
首先,你需要安装`matplotlib.animation`库,如果还没有安装,可以使用pip安装:
```bash
pip install matplotlib numpy
```
然后,你可以创建一个名为`AnimationVisualizer`的类,这个类包含一个方法`plot_3d_animation`,接受字典作为输入。这里是一个简单的示例:
```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
class AnimationVisualizer:
def __init__(self):
self.fig = plt.figure()
self.ax = fig.add_subplot(111, projection='3d')
self.data = None # 这里需要你在实例化时传入你的字典
def plot_3d_animation(self, data_dict):
if not self.data:
self.data = [(data_dict[t], t) for t in sorted(data_dict.keys())] # 将字典按时间排序并打包为列表
x, y, z = zip(*[coords for _, coords in self.data]) # 提取x, y, z坐标
self.line, = self.ax.plot(x, y, z, lw=2)
def update(frame):
index = frame % len(self.data)
self.line.set_data(*zip(*self.data[index]))
return self.line,
ani = FuncAnimation(self.fig, update, frames=np.arange(len(self.data)), interval=50, blit=True)
plt.show()
# 使用示例
my_dict = {t: (x, y, z) for t, (x, y, z) in enumerate(zip(range(10), range(10, 20), range(20, 30)))} # 示例字典
animator = AnimationVisualizer()
animator.plot_3d_animation(my_dict)
```
在这个例子中,我们创建了一个`update`函数作为`FuncAnimation`的核心,它会根据帧数更新动画中的线图。当你调用`plot_3d_animation`方法时,你需要提供一个字典,其键对应时间,值为三个坐标的元组。
阅读全文