yufeng0528 4 years ago
parent
commit
409d4863a6
1 changed files with 37 additions and 3 deletions
  1. 37 3
      linear/gradient_linear.py

+ 37 - 3
linear/gradient_linear.py

@@ -26,20 +26,54 @@ def cal_step_pow(data, w, b=3):
26 26
     sum_p = sum(p)
27 27
     return sum_p/len(data)
28 28
 
29
-def cal_mse(data, w, b=3):
29
+
30
+def cal_step_pow_b(data, w, b=1):
31
+    p = [(w*item[0][0] + b - item[1][0])*2  for item in data]
32
+    sum_p = sum(p)
33
+    return sum_p/len(data)
34
+
35
+
36
+def cal_mse(data, w, b=1):
30 37
     sum_p = sum([(w * item[0][0] + b - item[1][0]) * (w * item[0][0] + b - item[1][0]) for item in data])
31 38
     return sum_p / len(data)
32 39
 
40
+
33 41
 def train():
34 42
     train_data = read_data('train_data')
35 43
     w = random.uniform(-50, 50)
36
-    for i in range(10):
44
+    for i in range(50):
37 45
         step = cal_step_pow(train_data, w)*0.01
38 46
         mse = cal_mse(train_data, w)
39 47
         print w, step, mse
40 48
 
41 49
         w = w - step
50
+    return w
51
+
52
+
53
+def train_b(w):
54
+    train_data = read_data('train_data')
55
+    b = random.uniform(-50, 50)
56
+    for i in range(1000):
57
+        step = cal_step_pow_b(train_data, w, b)*0.01
58
+        mse = cal_mse(train_data, w, b)
59
+        print b, step, mse
60
+
61
+        b = b - step
62
+    return b
42 63
 
43 64
 
44 65
 if __name__ == '__main__':
45
-    train()
66
+    w = train()
67
+    print "__________"
68
+    b = train_b(w)
69
+    print "__________"
70
+    print w,b
71
+
72
+    train_data = read_data('train_data')
73
+    X, y = zip(*train_data)
74
+    X = np.array(X)
75
+    y = np.array(y)
76
+    model = LinearRegression()
77
+    # 一调用这个函数,就会不停地找合适的w和b 直到误差最小
78
+    model.fit(X, y)
79
+    print model.coef_, model.intercept_