机器学习之SVM实战鸢尾花数据二分类
【摘要】 SVM的线性核函数和高斯核函数比较
在SVM中利用对线性核函数和高斯核函数进行调参来对二分类鸢尾花数据分类结果进行比较。从最终的结果来看,调参是很重要的。
其代码如下:
import numpy as np from sklearn import svm import matplotlib as mpl import matplotlib.colors import matplotlib.pyplot as plt def show_accuracy(a, b): acc = a.ravel() == b.ravel() # print '正确率:%.2f%%' % (100*float(acc.sum()) / a.size) if __name__ == "__main__": data = np.loadtxt('bipartition.txt', dtype=np.float, delimiter='\t') x, y = np.split(data, (2, ), axis=1) y = y.ravel() clf_param = (('linear', 0.1), ('linear', 0.5), ('linear', 1), ('linear', 2), ('rbf', 1, 0.1), ('rbf', 1, 1), ('rbf', 1, 10), ('rbf', 1, 100), ('rbf', 5, 0.1), ('rbf', 5, 1), ('rbf', 5, 10), ('rbf', 5, 100)) x1_min, x1_max = x[:, 0].min(), x[:, 0].max() # 第0列的范围 x2_min, x2_max = x[:, 1].min(), x[:, 1].max() # 第1列的范围 x1, x2 = np.mgrid[x1_min:x1_max:200j, x2_min:x2_max:200j] # 生成网格采样点 grid_test = np.stack((x1.flat, x2.flat), axis=1) # 测试点 cm_light = mpl.colors.ListedColormap(['#77E0A0', '#FFA0A0']) cm_dark = mpl.colors.ListedColormap(['g', 'r']) mpl.rcParams['font.sans-serif'] = [u'SimHei'] mpl.rcParams['axes.unicode_minus'] = False plt.figure(figsize=(14, 10), facecolor='w') for i, param in enumerate(clf_param): clf = svm.SVC(C=param[1], kernel=param[0]) if param[0] == 'rbf': clf.gamma = param[2] title = u'高斯核,C=%.1f,$\gamma$ =%.1f' % (param[1], param[2]) else: title = u'线性核,C=%.1f' % param[1] clf.fit(x, y) y_hat = clf.predict(x) show_accuracy(y_hat, y) # 准确率 print(title) print('支撑向量的数目:', clf.n_support_) print ('支撑向量的系数:', clf.dual_coef_) print ('支撑向量:', clf.support_) grid_hat = clf.predict(grid_test) # 预测分类值 grid_hat = grid_hat.reshape(x1.shape) # 使之与输入的形状相同 z = clf.decision_function(grid_test) #print('clf.decision_function(x) = ', clf.decision_function(x)) #print('clf.predict(x) = ', clf.predict(x)) z = z.reshape(x1.shape)
得到结果示意图如下:
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)