train.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
'''
The simplest MSE example.
'''
  6. import sys
  7. import os
  8. sys.path.append(os.path.abspath('..'))
  9. import numpy as np
  10. from sklearn.tree import DecisionTreeRegressor
  11. from sklearn import metrics
  12. from util.config import config
  13. def read_data(path):
  14. with open(path) as f :
  15. lines=f.readlines()
  16. lines=[eval(line.strip()) for line in lines]
  17. X,y=zip(*lines)
  18. X=np.array(X)
  19. y=np.array(y)
  20. return X,y
  21. def demo():
  22. X_train,y_train=read_data(config.get('application', 'train_data_path'))
  23. X_test,y_test=read_data(config.get('application', 'test_data_path'))
  24. dt1 = DecisionTreeRegressor(max_depth=10)
  25. dt1.fit(X_train, y_train)
  26. y_pred_train = dt1.predict(X_train)
  27. train_mse = metrics.mean_squared_error(y_train, y_pred_train)
  28. print("训练集MSE:", train_mse)
  29. y_pred_test = dt1.predict(X_test)
  30. print(y_pred_test)
  31. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  32. print("测试集MSE:", test_mse)
  33. if __name__ == '__main__':
  34. demo()