data = pd.read_csv("data.csv") data.replace("M",1,inplace=True) data.replace("B",0,inplace=True) #获取特征x和特征y X = data.iloc[:, 3:5].values x = np.array(X) y = data.diagnosis y = np.array(y) #创建决策树算法对象 tree_clf = DecisionTreeClassifier(max_depth=2) #构建决策树 tree_clf.fit(x,y) #绘制决策树结构 tree.plot_tree(tree_clf) from matplotlib.colors import ListedColormap plt.rcParams["font.sans-serif"] = ["SimHei"] plt.rcParams["axes.unicode_minus"] = False #定义绘制决策树边界的函数 def plot_decision_boundary(clf, X, y, axes=[0, 10 , 0 , 5], data=True, legend=False, plot_training=True): x1s = np.linspace(axes[0], axes[1], 100) x2s = np.linspace(axes[2], axes[3], 100) x1, x2 = np.meshgrid(x1s, x2s) X_new = np.c_[x1.ravel(), x2.ravel()] y_pred = clf.predict(X_new).reshape(x1.shape) custom_cmap = ListedColormap(['#fafab0', '#0909ff', '#a0faa0']) plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap) if not data: custom_cmap2 = ListedColormap(['#7d7d58', '#4c4c7f', '#507d50']) plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8) if plot_training: plt.plot(X[:, 0][y == 0], X[:, 1][y == 0], "yo", label="0") plt.plot(X[:, 0][y == 1], X[:, 1][y == 1],"bs", label="1") plt.axis(axes) if data: plt.xlabel("属性",fontsize=14) plt.ylabel("特征",fontsize=14) else: plt.xlabel(r"$x_1$", fontsize=18) plt.xlabel(r"$x_2$", fontsize=18,rotation=0) if legend: plt.legend(loc="lower right", fontsize=14) tree_clf1 = DecisionTreeClassifier(random_state=42) tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4,random_state=43) tree_clf1.fit(x,y) tree_clf2.fit(x,y) plt.figure(figsize=(15,6)) plt.subplot(121) plot_decision_boundary(tree_clf1, x, y, axes=[0, 40, 50, 150], data=False) plt.title('圖一') plt.subplot(122) plot_decision_boundary(tree_clf2, x, y, axes=[0, 40, 50, 150], data=False) plt.title('圖二')
时间: 2024-04-02 12:31:14 浏览: 109
这段代码使用了决策树算法对数据进行分类,并绘制了决策树的结构以及决策边界。其中,数据需要先进行预处理,将"M"替换成1,"B"替换成0。然后使用特征x和特征y进行分类,其中x取data的第3到第5列,y取data的diagnosis列。接着,创建决策树对象,并使用fit()方法进行训练。最后使用plot_decision_boundary()函数绘制决策树的结构和决策边界。该函数会根据传入的决策树模型,数据特征和标签进行绘制,可以进行分类的数据点用蓝色正方形表示,不可分类的数据点用黄色圆圈表示。其中,图一的决策树没有设置最小叶子节点样本数,图二的决策树设置了最小叶子节点样本数为4。
相关问题
import numpy as np import pandas as pd train_data = pd.read_csv("C://Users//Dell//Desktop//数据分析作业//adult_train(1).csv") test_data = pd.read_csv("C://Users//Dell//Desktop//数据分析作业//adult_test.csv") #写入csv文件 columns = ['Age','Workclass','fnlgwt','Education','EdNum','MaritalStatus', 'Occupation','Relationship','Race','Sex','CapitalGain', 'CapitalLoss','HoursPerWeek','Country','Income'] #写入名称 df_train_set = pd.read_csv('C://Users//Dell//Desktop//数据分析作业//adult_train(1).csv', names=columns) #跳过表头 df_test_set = pd.read_csv('C://Users//Dell//Desktop//数据分析作业//adult_test.csv', names=columns, skiprows=1) #删除ID序列号栏 df_train_set.drop('fnlgwt', axis=1, inplace=True) df_test_set.drop('fnlgwt', axis=1, inplace=True) #用unknown替换? for i in df_train_set.columns: df_train_set[i].replace('?', 'Unknown', inplace=True) df_test_set[i].replace('?', 'Unknown', inplace=True) #去掉非int64类型数据中的点和空格 for col in df_train_set.columns: if df_train_set[col].dtype != 'int64': df_train_set[col] = df_train_set[col].apply(lambda val: val.replace(" ", "")) df_train_set[col] = df_train_set[col].apply(lambda val: val.replace(".", "")) df_test_set[col] = df_test_set[col].apply(lambda val: val.replace(" ", "")) df_test_set[col] = df_test_set[col].apply(lambda val: val.replace(".", "")) # 将训练集导出为 csv 文件 df_train_set.to_csv('train_set.csv', index=False) # 将测试集导出为 csv 文件 df_test_set.to_csv('test_set.csv', index=False)
这段代码是用来读取和预处理成人收入数据集(Adult dataset)的。其中包括读取训练集和测试集的csv文件,去掉ID序列号栏,用"Unknown"替换缺失值"?",去除非int64类型数据中的点和空格,并将处理后的数据集导出为csv文件。该数据集是一个二分类问题,目标是预测一个人的收入是否超过50K美元。这段代码中的预处理是为了对数据进行清洗和转换,以便更好地进行建模和预测。
import pandas as pd data= pd.read_csv('C:\\Users\\admin\\Desktop\\taobao.product2.csv', encoding='utf8',engine='python') data.drop('_id',axis=1, inplace=True)#去掉id列 data['price'].replace('¥', '', regex=True, inplace=True) #去掉price列的'¥' #从location列中取出省份和城市,然后删除location列 data['province']=data.location.apply(lambda x:x.split()[0]) data['city']=data.location.apply(lambda x:x.split()[0] if len(x)<4 else x.split()[1]) data.drop('location',axis=1, inplace=True) #数据类型转化 data['price']=data.price.astype('float64') for i in ['province','city']: data[i]=data[i].astype('category')
这段代码是用来读取一个名为taobao.product2.csv的CSV文件,并对数据进行一些处理操作。
首先,使用pandas库导入数据,并指定文件路径和编码格式。然后,通过设置engine='python'来使用python解析CSV文件。
接下来,使用drop函数去掉名为'_id'的列。使用replace函数去掉'price'列中的'¥'符号。
然后,从'location'列中提取省份和城市信息,并将其分别存储在新的'province'和'city'列中。如果'location'的长度小于4,则省份和城市相同;否则,取第二个元素作为城市。
最后,使用astype函数将'price'、'province'和'city'的数据类型转换为float64和category。
请问还有其他问题吗?
阅读全文