train.py 955 B

123456789101112131415161718192021222324252627282930313233343536373839404142
#!/usr/bin/python
# -*- coding: UTF-8 -*-
'''
Simplest MSE baseline.

Fits a DecisionTreeRegressor on the training data and reports the mean
squared error on both the training set and the test set.
'''
import sys
import os
# Make the parent of the *current working directory* importable.
# NOTE(review): os.path.abspath('..') resolves against the CWD, not this
# file's location, and no import visible in this file relies on it --
# confirm whether it is still needed.
sys.path.append(os.path.abspath('..'))
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn import metrics
  12. def read_data(path):
  13. with open(path) as f :
  14. lines=f.readlines()
  15. lines=[eval(line.strip()) for line in lines]
  16. X,y=zip(*lines)
  17. X=np.array(X)
  18. y=np.array(y)
  19. return X,y
  20. def demo():
  21. X_train,y_train=read_data("../bbztx/train_data")
  22. X_test,y_test=read_data("../bbztx/test_data")
  23. dt1 = DecisionTreeRegressor(max_depth=10)
  24. dt1.fit(X_train, y_train)
  25. y_pred_train = dt1.predict(X_train)
  26. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  27. print("训练集MSE:", train_mse)
  28. y_pred_test = dt1.predict(X_test)
  29. print(y_pred_test)
  30. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  31. print("测试集MSE:", test_mse)
  32. if __name__ == '__main__':
  33. demo()