Browse Source

把测试数据和预测线画在一块

yufeng0528 4 years ago
parent
commit
d8e33d7d80
2 changed files with 34 additions and 3 deletions
  1. 26 1
      draw/draw_util.py
  2. 8 2
      linear/train.py

+ 26 - 1
draw/draw_util.py

@@ -1,6 +1,7 @@
1 1
 #!/usr/bin/python
2 2
 # -*- coding: UTF-8 -*-
3 3
 import matplotlib.pyplot as plt
4
+import numpy as np
4 5
 
5 6
 #绘制散点图
6 7
 def drawScatter(heights,weights):
@@ -14,7 +15,31 @@ def drawScatter(heights,weights):
14 15
     plt.show()
15 16
 
16 17
 
18
+def drawLine(w, b):
19
+    plt.xlabel("x")
20
+    plt.ylabel("y")
21
+    x = np.arange(0, 10)
22
+    y = w*x + b
23
+
24
+    plt.plot(x, y)
25
+    plt.show()
26
+
27
+
28
+def drawScatterAndLine(p, q, w, b):
29
+    plt.scatter(p, q)
30
+    plt.xlabel('p')
31
+    plt.ylabel('q')
32
+    plt.title('line regesion')
33
+
34
+    x = np.arange(0, 11)
35
+    y = w * x + b
36
+
37
+    plt.plot(x, y, color='red')
38
+    plt.show()
39
+
40
+
17 41
 if __name__ == '__main__':
18 42
     heights = [1.5, 1.7]
19 43
     weights = [43, 61]
20
-    drawScatter(heights,weights)
44
+    drawScatter(heights,weights)
45
+    drawLine(1,2)

+ 8 - 2
linear/train.py

@@ -68,6 +68,7 @@ def test():
68 68
 	with open("test_curve.csv","w") as f :
69 69
 		f.writelines("\n".join(test_curve))
70 70
 	'''
71
+	return X_train,y_train, model.coef_, model.intercept_
71 72
 
72 73
 
73 74
 def draw_line():
@@ -78,7 +79,12 @@ def draw_line():
78 79
 
79 80
 
80 81
 if __name__ == '__main__':
81
-	draw_line()
82
-	test()
82
+	# draw_line()
83
+	p, q, w,b = test()
84
+	p = [i[0] for i in p.tolist()]
85
+	q = [i[0] for i in q.tolist()]
86
+	w = w[0]
87
+	b = b[0]
88
+	draw_util.drawScatterAndLine(p, q, w, b)
83 89
 
84 90