Skip to content

Commit eb27bf9

Browse files
committed
📝 some updates
1 parent d3d71d8 commit eb27bf9

File tree

2 files changed

+43
-44
lines changed

2 files changed

+43
-44
lines changed

LogisticRegression/LogisticRegression_scikit-learn.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from sklearn.linear_model import LogisticRegression
44
from sklearn.preprocessing import StandardScaler
5-
from sklearn.cross_validation import train_test_split
5+
# from sklearn.cross_validation import train_test_split # 0.18版本之后废弃
6+
from sklearn.model_selection import train_test_split
67
import numpy as np
78

89
def logisticRegression():

SVM/SVM_scikit-learn.py

+41-43
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,69 @@
1+
#-*- coding: utf-8 -*-
12
import numpy as np
23
from scipy import io as spio
34
from matplotlib import pyplot as plt
45
from sklearn import svm
56

7+
68
def SVM():
7-
'''data1——线性分类'''
9+
'''data1——线性分类'''
810
data1 = spio.loadmat('data1.mat')
911
X = data1['X']
1012
y = data1['y']
1113
y = np.ravel(y)
12-
plot_data(X,y)
13-
14-
model = svm.SVC(C=1.0,kernel='linear').fit(X,y) # 指定核函数为线性核函数
15-
plot_decisionBoundary(X, y, model) # 画决策边界
16-
'''data2——非线性分类'''
14+
plot_data(X, y)
15+
16+
model = svm.SVC(C=1.0, kernel='linear').fit(X, y) # 指定核函数为线性核函数
17+
plot_decisionBoundary(X, y, model) # 画决策边界
18+
'''data2——非线性分类'''
1719
data2 = spio.loadmat('data2.mat')
1820
X = data2['X']
1921
y = data2['y']
2022
y = np.ravel(y)
21-
plt = plot_data(X,y)
23+
plt = plot_data(X, y)
2224
plt.show()
23-
24-
model = svm.SVC(gamma=100).fit(X,y) # gamma为核函数的系数,值越大拟合的越好
25-
plot_decisionBoundary(X, y, model,class_='notLinear') # 画决策边界
26-
27-
28-
29-
# 作图
30-
def plot_data(X,y):
31-
plt.figure(figsize=(10,8))
32-
pos = np.where(y==1) # 找到y=1的位置
33-
neg = np.where(y==0) # 找到y=0的位置
34-
p1, = plt.plot(np.ravel(X[pos,0]),np.ravel(X[pos,1]),'ro',markersize=8)
35-
p2, = plt.plot(np.ravel(X[neg,0]),np.ravel(X[neg,1]),'g^',markersize=8)
25+
26+
model = svm.SVC(gamma=100).fit(X, y) # gamma为核函数的系数,值越大拟合的越好
27+
plot_decisionBoundary(X, y, model, class_='notLinear') # 画决策边界
28+
29+
30+
# 作图
31+
def plot_data(X, y):
32+
plt.figure(figsize=(10, 8))
33+
pos = np.where(y == 1) # 找到y=1的位置
34+
neg = np.where(y == 0) # 找到y=0的位置
35+
p1, = plt.plot(np.ravel(X[pos, 0]), np.ravel(X[pos, 1]), 'ro', markersize=8)
36+
p2, = plt.plot(np.ravel(X[neg, 0]), np.ravel(X[neg, 1]), 'g^', markersize=8)
3637
plt.xlabel("X1")
3738
plt.ylabel("X2")
38-
plt.legend([p1,p2],["y==1","y==0"])
39+
plt.legend([p1, p2], ["y==1", "y==0"])
3940
return plt
40-
41-
# 画决策边界
42-
def plot_decisionBoundary(X,y,model,class_='linear'):
41+
42+
43+
# 画决策边界
44+
def plot_decisionBoundary(X, y, model, class_='linear'):
4345
plt = plot_data(X, y)
44-
45-
# 线性边界
46-
if class_=='linear':
46+
47+
# 线性边界
48+
if class_ == 'linear':
4749
w = model.coef_
4850
b = model.intercept_
49-
xp = np.linspace(np.min(X[:,0]),np.max(X[:,0]),100)
50-
yp = -(w[0,0]*xp+b)/w[0,1]
51-
plt.plot(xp,yp,'b-',linewidth=2.0)
51+
xp = np.linspace(np.min(X[:, 0]), np.max(X[:, 0]), 100)
52+
yp = -(w[0, 0] * xp + b) / w[0, 1]
53+
plt.plot(xp, yp, 'b-', linewidth=2.0)
5254
plt.show()
53-
else: # 非线性边界
54-
x_1 = np.transpose(np.linspace(np.min(X[:,0]),np.max(X[:,0]),100).reshape(1,-1))
55-
x_2 = np.transpose(np.linspace(np.min(X[:,1]),np.max(X[:,1]),100).reshape(1,-1))
56-
X1,X2 = np.meshgrid(x_1,x_2)
55+
else: # 非线性边界
56+
x_1 = np.transpose(np.linspace(np.min(X[:, 0]), np.max(X[:, 0]), 100).reshape(1, -1))
57+
x_2 = np.transpose(np.linspace(np.min(X[:, 1]), np.max(X[:, 1]), 100).reshape(1, -1))
58+
X1, X2 = np.meshgrid(x_1, x_2)
5759
vals = np.zeros(X1.shape)
5860
for i in range(X1.shape[1]):
59-
this_X = np.hstack((X1[:,i].reshape(-1,1),X2[:,i].reshape(-1,1)))
60-
vals[:,i] = model.predict(this_X)
61-
62-
plt.contour(X1,X2,vals,[0,1],color='blue')
61+
this_X = np.hstack((X1[:, i].reshape(-1, 1), X2[:, i].reshape(-1, 1)))
62+
vals[:, i] = model.predict(this_X)
63+
64+
plt.contour(X1, X2, vals, [0, 1], color='blue')
6365
plt.show()
64-
6566

6667

6768
if __name__ == "__main__":
68-
SVM()
69-
70-
71-
69+
SVM()

0 commit comments

Comments
 (0)