本篇文章延續上一篇:Python實作各種水果分類器那一種最好?
程式原始碼:https://github.com/susanli2016/Machine-Learning-with-Python/blob/master/Solving%20A%20Simple%20Classification%20Problem%20with%20Python.ipynb
水果資料來源:https://github.com/susanli2016/Machine-Learning-with-Python/blob/master/fruit_data_with_colors.txt
在上一篇文章中,我們發現最近鄰居法(KNN)在水果分類上表現最好。本篇將原始程式整理後列出,並以matplotlib套件畫出水果分類圖。原始程式的第25行有一個小錯誤,此處已修正,讀者不妨試試。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.patches as mpatches
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Load the fruit data set (read_table expects tab-separated text by default)
# and hold out a test split with a fixed seed so the plot is reproducible.
fruits = pd.read_table('fruit_data_with_colors.txt')

X = fruits[['mass', 'width', 'height', 'color_score']]
y = fruits['fruit_label']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)


def plot_fruit_knn(X: pd.DataFrame, y: pd.Series, n_neighbors: int,
                   weights: str) -> None:
    """Fit a k-NN classifier on height/width and plot its decision regions.

    Only the 'height' and 'width' columns of *X* are used, so the decision
    surface can be drawn in two dimensions. The fitted classifier is
    evaluated on a dense grid covering the data (padded by 1 on each side),
    the grid is painted with light colours, and the training points are
    overlaid in bold colours.

    Args:
        X: Feature table containing at least 'height' and 'width' columns.
        y: Class label for each row of *X*.
        n_neighbors: Number of neighbours passed to KNeighborsClassifier.
        weights: Weighting scheme passed to KNeighborsClassifier
            (e.g. 'uniform').
    """
    # DataFrame.as_matrix()/Series.as_matrix() were removed in pandas 1.0;
    # to_numpy() is the supported replacement.
    X_mat = X[['height', 'width']].to_numpy()
    y_mat = y.to_numpy()

    # Create color maps: light shades for regions, bold shades for points.
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

    clf = KNeighborsClassifier(n_neighbors, weights=weights)
    clf.fit(X_mat, y_mat)

    # Plot the decision boundary by assigning a color in the color map
    # to each mesh point.
    mesh_step_size = .01  # step size in the mesh
    plot_symbol_size = 50

    x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
    y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                         np.arange(y_min, y_max, mesh_step_size))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot.
    Z = Z.reshape(xx.shape)

    # Use one shared label range for both the mesh and the points. Without
    # it each artist normalises on its own data, so a class that is missing
    # from the predicted grid would shift the region colours away from the
    # point and legend colours.
    label_min, label_max = y_mat.min(), y_mat.max()

    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light,
                   vmin=label_min, vmax=label_max)

    # Plot training points.
    plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y_mat,
                cmap=cmap_bold, edgecolor='black',
                vmin=label_min, vmax=label_max)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    # NOTE(review): the legend assumes the four label values map, in
    # ascending order, to apple, mandarin, orange, lemon - confirm against
    # the data file.
    patch0 = mpatches.Patch(color='#FF0000', label='apple')
    patch1 = mpatches.Patch(color='#00FF00', label='mandarin')
    patch2 = mpatches.Patch(color='#0000FF', label='orange')
    patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')
    plt.legend(handles=[patch0, patch1, patch2, patch3])

    # Column 0 of X_mat is height (x axis), column 1 is width (y axis).
    plt.xlabel('height (cm)')
    plt.ylabel('width (cm)')
    plt.title("4-Class classification (k = %i, weights = '%s')"
              % (n_neighbors, weights))
    plt.show()


plot_fruit_knn(X_train, y_train, 5, 'uniform')
沒有留言:
張貼留言