123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- #!/usr/bin/python
- # -*- coding: UTF-8 -*-
- import sys
- reload(sys)
- sys.setdefaultencoding('utf-8')
- import random
- import math
- import matplotlib.pyplot as plt
- import numpy as np
def read_data(path):
    """Load samples from a text file holding one Python expression per line.

    Args:
        path: Path of the file to read.

    Returns:
        A list with one evaluated object per non-blank line, in file order.
    """
    samples = []
    with open(path) as f:
        for line in f:
            text = line.strip()
            if not text:
                # A blank line carries no sample; eval("") would raise
                # SyntaxError, so skip it.
                continue
            # SECURITY: eval executes arbitrary expressions, so only use this
            # on trusted files. If every line is a plain literal,
            # ast.literal_eval would be the safer choice.
            samples.append(eval(text))
    # The parsed samples were previously built and then discarded.
    return samples
def cal_mse(data, w, b=3):
    """Return the mean squared error of the line y = w*x + b over data.

    Args:
        data: Sequence of two-item samples [x, y]; only x[0] and y[0] are
            read from each sample.
        w: Slope of the line.
        b: Intercept of the line (default 3).

    Returns:
        The sum of squared residuals divided by the number of samples.

    Raises:
        ZeroDivisionError: If data is empty.
    """
    squared_error = 0
    for features, target in data:
        residual = target[0] - (w * features[0] + b)
        squared_error += residual * residual
    # float() forces true division; under Python 2 an int / int quotient
    # would be floored and silently truncate the error.
    return squared_error / float(len(data))
# Load the full training set: every line of "train_data" is evaluated as a
# Python expression that yields one sample.
# NOTE(review): eval runs arbitrary code, so the file must be trusted.
with open("train_data") as f:
    lines = f.readlines()
data = []
for line in lines:
    data.append(eval(line.strip()))

# Random subsets drawn without replacement: two of size 10, then 50 and 100.
sub_data1 = random.sample(data, 10)
sub_data2 = random.sample(data, 10)
sub_data3 = random.sample(data, 50)
sub_data4 = random.sample(data, 100)

# Each entry of `results` is one comma-separated row: the swept parameter,
# the MSE on the full data, then the MSE on each of the four subsets.
results = []

# Accumulators for the curve currently being plotted.
x = []
y = []
def drawLine(x, y):
    """Plot five MSE curves over a shared horizontal axis and show the figure.

    The horizontal axis is always labelled "w" and the vertical axis "mse".

    Args:
        x: Values for the horizontal axis.
        y: Sequence of records aligned with x; items 0 to 4 of each record
            give the heights of the black, red, green, blue and yellow
            curves respectively.
    """
    plt.xlabel("w")
    plt.ylabel("mse")
    curve_colors = ("black", "red", "green", "blue", "yellow")
    for column, shade in enumerate(curve_colors):
        plt.plot(x, [record[column] for record in y], color=shade)
    plt.show()
# Sweep w from 1.0 up to 8.998 in steps of 0.002, leaving b at cal_mse's
# default, and record the MSE on the full data set and on every subset.
for i in range(-2000, 2000):
    w = 5 + i / 500.0
    errors = tuple(
        cal_mse(samples, w)
        for samples in (data, sub_data1, sub_data2, sub_data3, sub_data4)
    )
    results.append("{},{},{},{},{},{}".format(w, *errors))
    x.append(w)
    y.append(errors)
drawLine(x, y)

# Start the next sweep with empty plot accumulators.
x = []
y = []
# Sweep b from -3.0 up to 4.998 in steps of 0.002 with w held at 5, and
# record the MSE on the full data set and on every subset.
for i in range(-2000, 2000):
    b = 1 + 1.0 * i / 500
    mse = cal_mse(data, 5, b)
    sub_mse1 = cal_mse(sub_data1, 5, b)
    sub_mse2 = cal_mse(sub_data2, 5, b)
    sub_mse3 = cal_mse(sub_data3, 5, b)
    sub_mse4 = cal_mse(sub_data4, 5, b)
    # Record the swept value b in the first column. The row used to carry
    # `w`, a stale leftover from the previous sweep that never changes here.
    results.append(
        "{},{},{},{},{},{}".format(b, mse, sub_mse1, sub_mse2, sub_mse3, sub_mse4)
    )
    x.append(b)
    y.append((mse, sub_mse1, sub_mse2, sub_mse3, sub_mse4))
drawLine(x, y)
- '''
- with open("mse_curve.csv","w") as f :
- f.writelines("\n".join(results))
- '''
|