initial commit
This commit is contained in:
commit
04f5361233
406
ML_P4_SVM.ipynb
Normal file
406
ML_P4_SVM.ipynb
Normal file
File diff suppressed because one or more lines are too long
3
utils/__init__.py
Normal file
3
utils/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .helper import plot_data, plot_centroids, plot_decision_boundaries, plot_svc_decision_boundary, make_xor, plot_predictions, plot_dataset
|
||||
|
||||
__all__ = ('plot_data', 'plot_centroids', 'plot_decision_boundaries', 'plot_svc_decision_boundary', 'make_xor', 'plot_predictions', 'plot_dataset')
|
103
utils/helper.py
Normal file
103
utils/helper.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Python ≥3.5 is required
|
||||
import sys
|
||||
assert sys.version_info >= (3, 5)
|
||||
|
||||
# Common imports
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# To plot pretty figures
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
mpl.rc('axes', labelsize=14)
|
||||
mpl.rc('xtick', labelsize=12)
|
||||
mpl.rc('ytick', labelsize=12)
|
||||
|
||||
# Where to save the figures
|
||||
PROJECT_ROOT_DIR = "."
|
||||
CHAPTER_ID = "svm"
|
||||
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
|
||||
os.makedirs(IMAGES_PATH, exist_ok=True)
|
||||
|
||||
|
||||
def plot_data(X):
|
||||
plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)
|
||||
|
||||
|
||||
def plot_centroids(centroids, weights=None, circle_color='w', cross_color='k'):
|
||||
if weights is not None:
|
||||
centroids = centroids[weights > weights.max() / 10]
|
||||
plt.scatter(centroids[:, 0], centroids[:, 1],
|
||||
marker='o', s=30, linewidths=8,
|
||||
color=circle_color, zorder=10, alpha=0.9)
|
||||
plt.scatter(centroids[:, 0], centroids[:, 1],
|
||||
marker='x', s=50, linewidths=50,
|
||||
color=cross_color, zorder=11, alpha=1)
|
||||
|
||||
|
||||
def plot_decision_boundaries(clusterer, X, resolution=1000, show_centroids=True,
|
||||
show_xlabels=True, show_ylabels=True):
|
||||
mins = X.min(axis=0) - 0.1
|
||||
maxs = X.max(axis=0) + 0.1
|
||||
xx, yy = np.meshgrid(np.linspace(mins[0], maxs[0], resolution),
|
||||
np.linspace(mins[1], maxs[1], resolution))
|
||||
Z = clusterer.predict(np.c_[xx.ravel(), yy.ravel()])
|
||||
Z = Z.reshape(xx.shape)
|
||||
|
||||
plt.contourf(Z, extent=(mins[0], maxs[0], mins[1], maxs[1]),
|
||||
cmap="Pastel2")
|
||||
plt.contour(Z, extent=(mins[0], maxs[0], mins[1], maxs[1]),
|
||||
linewidths=1, colors='k')
|
||||
plot_data(X)
|
||||
if show_centroids:
|
||||
plot_centroids(clusterer.cluster_centers_)
|
||||
|
||||
if show_xlabels:
|
||||
plt.xlabel("$x_1$", fontsize=14)
|
||||
else:
|
||||
plt.tick_params(labelbottom=False)
|
||||
if show_ylabels:
|
||||
plt.ylabel("$x_2$", fontsize=14, rotation=0)
|
||||
else:
|
||||
plt.tick_params(labelleft=False)
|
||||
|
||||
|
||||
def plot_svc_decision_boundary(svm_clf, xmin, xmax):
|
||||
w = svm_clf.coef_[0]
|
||||
b = svm_clf.intercept_[0]
|
||||
|
||||
# At the decision boundary, w0*x0 + w1*x1 + b = 0
|
||||
# => x1 = -w0/w1 * x0 - b/w1
|
||||
x0 = np.linspace(xmin, xmax, 200)
|
||||
decision_boundary = -w[0] / w[1] * x0 - b / w[1]
|
||||
|
||||
margin = 1 / w[1]
|
||||
gutter_up = decision_boundary + margin
|
||||
gutter_down = decision_boundary - margin
|
||||
|
||||
svs = svm_clf.support_vectors_
|
||||
plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#FFAAAA')
|
||||
plt.plot(x0, decision_boundary, "k-", linewidth=2)
|
||||
plt.plot(x0, gutter_up, "k--", linewidth=2)
|
||||
plt.plot(x0, gutter_down, "k--", linewidth=2)
|
||||
|
||||
def plot_predictions(clf, axes):
|
||||
x0s = np.linspace(axes[0], axes[1], 100)
|
||||
x1s = np.linspace(axes[2], axes[3], 100)
|
||||
x0, x1 = np.meshgrid(x0s, x1s)
|
||||
X = np.c_[x0.ravel(), x1.ravel()]
|
||||
y_pred = clf.predict(X).reshape(x0.shape)
|
||||
y_decision = clf.decision_function(X).reshape(x0.shape)
|
||||
plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)
|
||||
plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)
|
||||
|
||||
def plot_dataset(X, y):
|
||||
plt.plot(X[:, 0][y == 0], X[:, 1][y == 0], "bs")
|
||||
plt.plot(X[:, 0][y == 1], X[:, 1][y == 1], "g^")
|
||||
# plt.axis(axes)
|
||||
plt.grid(True, which='both')
|
||||
plt.xlabel(r"$x_1$", fontsize=20)
|
||||
plt.ylabel(r"$x_2$", fontsize=20, rotation=0)
|
||||
|
||||
def make_xor():
|
||||
pass
|
Loading…
x
Reference in New Issue
Block a user