train.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # -*- encoding:utf-8 -*-
  2. from sklearn import datasets
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.linear_model import LogisticRegression
  5. from sklearn.model_selection import cross_val_predict
  6. from numpy import shape
  7. from sklearn import metrics
  8. from sklearn.metrics import log_loss
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. def read_data(path):
  12. with open(path) as f:
  13. lines = f.readlines()
  14. lines = [eval(line.strip()) for line in lines]
  15. X, y = zip(*lines)
  16. X = np.array(X)
  17. y = np.array(y)
  18. return X, y
  19. def curve(x_train, w, w0):
  20. results = x_train.tolist()
  21. for i in range(0, 100):
  22. x1 = 1.0 * i / 10
  23. x2 = -1 * (w[0] * x1 + w0) / w[1]
  24. results.append([x1, x2])
  25. results = ["{},{}".format(x1, x2) for [x1, x2] in results]
  26. return results
  27. def drawScatterAndLine(p, q):
  28. x1 = []
  29. x2 = []
  30. y1 = []
  31. y2 = []
  32. for idx,i in enumerate(q):
  33. if i == 0:
  34. x1.append(p[idx][0])
  35. y1.append(p[idx][1])
  36. else:
  37. x2.append(p[idx][0])
  38. y2.append(p[idx][1])
  39. plt.scatter(x1, y1)
  40. plt.scatter(x2, y2)
  41. plt.xlabel('p')
  42. plt.ylabel('q')
  43. plt.title('line regesion')
  44. plt.show()
  45. def main():
  46. X_train, y_train = read_data("train_data")
  47. drawScatterAndLine(X_train, y_train)
  48. X_test, y_test = read_data("test_data")
  49. model = LogisticRegression()
  50. model.fit(X_train, y_train)
  51. print("w", model.coef_)
  52. print("w0", model.intercept_)
  53. y_pred = model.predict(X_test)
  54. print(y_pred)
  55. # y_pred = model.predict_proba(X_test)
  56. # print y_pred
  57. # loss=log_loss(y_test,y_pred)
  58. # print "KL_loss:",loss
  59. # loss=log_loss(y_pred,y_test)
  60. # print "KL_loss:",loss
  61. '''
  62. curve_results=curve(X_train,model.coef_.tolist()[0],model.intercept_.tolist()[0])
  63. with open("train_with_splitline","w") as f :
  64. f.writelines("\n".join(curve_results))
  65. '''
  66. if __name__ == '__main__':
  67. main()