Browse Source

L1L2正则

yufeng0528 4 years ago
parent
commit
dd98df8dca
1 changed files with 46 additions and 0 deletions
  1. 46 0
      logistic/cancer_train_l1l2.py

+ 46 - 0
logistic/cancer_train_l1l2.py

@@ -0,0 +1,46 @@
1
+# -*- encoding:utf-8 -*-
2
+from sklearn import datasets
3
+from sklearn.linear_model import LogisticRegression
4
+from numpy import shape
5
+from sklearn import metrics
6
+import numpy as np
7
+
8
+
9
def read_data(path):
    """Load (features, label) records from a text file, one record per line.

    Every non-blank line is evaluated as a Python expression that must yield
    a two-item sequence: the feature vector and its label. Blank lines are
    skipped so that stray or trailing empty lines do not abort loading.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(X, y)`` of numpy arrays: ``X`` holds the first item of
        every record, ``y`` the second.

    Raises:
        ValueError: If the file contains no records (nothing to unpack).
    """
    records = []
    with open(path) as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            text = line.strip()
            if not text:
                continue
            # NOTE(security): eval() runs arbitrary code found in the file.
            # Only use this on trusted data; if every record is a plain
            # literal, ast.literal_eval would be the safe replacement.
            records.append(eval(text))
    # Transpose the list of (features, label) pairs into two sequences.
    X, y = zip(*records)
    return np.array(X), np.array(y)
17
+
18
+
19
# Load both splits once at import time; train_model() reads these
# module-level arrays instead of receiving them as arguments.
X_train, y_train = read_data("cancer_train_data")
X_test, y_test = read_data("cancer_test_data")


def train_model(reg):
    """Fit a regularised logistic regression and print its train/test metrics.

    Prints the penalty name, the learned coefficients, the mean squared
    error and log loss on both splits, and the test-minus-train error gap.

    Args:
        reg: Penalty name forwarded to ``LogisticRegression(penalty=...)``,
            e.g. ``"l1"`` or ``"l2"``.
    """
    print(reg)
    # Pin liblinear: it supports both "l1" and "l2", whereas the default
    # solver of newer scikit-learn releases (lbfgs) rejects "l1".
    model = LogisticRegression(penalty=reg, solver="liblinear")
    model.fit(X_train, y_train)
    print("w {}".format(model.coef_))

    # Error on the hard class predictions.
    e_train = metrics.mean_squared_error(y_train, model.predict(X_train))
    e_test = metrics.mean_squared_error(y_test, model.predict(X_test))

    # log_loss scores predicted probabilities, not hard labels; passing the
    # class list keeps it well-defined even if a split holds a single class.
    kl_train = metrics.log_loss(
        y_train, model.predict_proba(X_train), labels=model.classes_)
    kl_test = metrics.log_loss(
        y_test, model.predict_proba(X_test), labels=model.classes_)

    print("训练集MSE:{}, KL:{}".format(e_train, kl_train))
    print("测试集MSE:{}, KL:{}".format(e_test, kl_test))
    print("训练测试差异{}".format(e_test-e_train))
    # Blank separator line between successive runs.
    print("")


# NOTE: the unregularised baseline run (reg="None") is disabled.
train_model(reg="l1")
train_model(reg="l2")