上课笔记

  • 导入数据
from sklearn import svm
import numpy as np
import sklearn

# 取数据集
from sklearn.datasets import load_iris
# 读入数据文件
load_data = load_iris()
# 导入数据和标签
data_X = load_data.data
data_X = data_X[:100,:2] # 为了后面的可视化,简化特征
data_Y = load_data.target
data_Y = data_Y[:100]
print(data_Y)
  • 划分样本数据和标签,抽取训练集
# 划分样本数据和标签 和 抽取训练集和测试集
train_data,test_data,train_label,test_label=sklearn.model_selection.train_test_split(data_X,data_Y,random_state=1,train_size=0.6,test_size=0.4)
print(train_data.shape)
print(test_data.shape)
  • 定义svm并训练
# 定义svm并训练
classifier=svm.SVC(C=2,kernel='rbf',gamma=10,decision_function_shape='ovr')
classifier.fit(train_data,train_label.ravel())

# 识别率
print('训练集:',classifier.score(train_data,train_label))
print('测试集:',classifier.score(test_data,test_label))
  • 绘制图形
#绘制图形
# 确定坐标轴范围
import matplotlib
import matplotlib.pyplot as plt

x1_min, x1_max = data_X[:, 0].min(), data_X[:, 0].max()  # 第0维特征的范围
x2_min, x2_max = data_X[:, 1].min(), data_X[:, 1].max()  # 第1维特征的范围
x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j]  # 生成网络采样点
grid_test = np.stack((x1.flat, x2.flat), axis=1)  # 测试点
# 指定默认字体
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
# 设置颜色
cm_light = matplotlib.colors.ListedColormap(['#A0FFA0', '#FFA0A0'])  
cm_dark = matplotlib.colors.ListedColormap(['g', 'r']) 
grid_hat = classifier.predict(grid_test)  # 预测分类值
grid_hat = grid_hat.reshape(x1.shape)  # 使之与输入的形状相同

plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)  # 预测值的显示
plt.scatter(data_X[:, 0], data_X[:, 1], c=data_Y, s=30, cmap=cm_dark)  # 原样本
plt.scatter(test_data[:, 0], test_data[:, 1], c=test_label, s=30, edgecolors='k', zorder=2,cmap=cm_dark)  # 圈中测试集样本点
plt.xlabel('花萼长度', fontsize=13)
plt.ylabel('花萼宽度', fontsize=13)
plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
plt.title('鸢尾花SVM二特征分类')

plt.show()
print('ok')
最后修改:2021 年 11 月 21 日
如果觉得我的文章对你有用,请随意赞赏