123456789101112131415161718192021222324252627282930313233343536373839404142 |
- #!/usr/bin/python
- # -*- coding: UTF-8 -*-
- '''
- 最简单的mse
- '''
- import sys
- import os
- sys.path.append(os.path.abspath('..'))
- import numpy as np
- from sklearn.tree import DecisionTreeRegressor
- from sklearn import metrics
- from util.config import config
def read_data(path):
    """Load ``(features, target)`` samples from a text data file.

    Every non-blank line of the file is stripped and evaluated as a Python
    expression. Each one is expected to yield a 2-item sequence ``(x, y)``.
    The samples are then transposed into one array of all ``x`` values and
    one array of all ``y`` values.

    Args:
        path: Path of the data file to read.

    Returns:
        A tuple ``(X, y)`` of NumPy arrays holding the ``x`` parts and the
        ``y`` parts of the samples, in file order.

    Raises:
        ValueError: If the file contains no non-blank lines.
    """
    with open(path, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing empty line) are skipped: eval("") would
        # raise SyntaxError.
        # NOTE(review): eval() executes arbitrary code taken from the file. If
        # the data file may come from an untrusted source, switch to
        # ast.literal_eval, which only accepts Python literals.
        samples = [eval(text) for text in (raw.strip() for raw in f) if text]
    if not samples:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError("no samples found in data file: %r" % (path,))
    X, y = zip(*samples)
    return np.array(X), np.array(y)
def demo(max_depth=10, random_state=None):
    """Fit a decision-tree regressor and print train/test mean squared error.

    Training and test data are loaded with ``read_data`` from the paths
    stored in ``config`` under section ``application``, keys
    ``train_data_path`` and ``test_data_path``.

    Args:
        max_depth: Maximum depth of the tree. The default of 10 is the value
            this demo previously hard-coded.
        random_state: Seed forwarded to ``DecisionTreeRegressor``. ``None``
            (the default) keeps scikit-learn's unseeded behaviour. Pass an int
            for reproducible trees, since the estimator randomly permutes the
            features at each split.
    """
    X_train, y_train = read_data(config.get('application', 'train_data_path'))
    X_test, y_test = read_data(config.get('application', 'test_data_path'))

    model = DecisionTreeRegressor(max_depth=max_depth, random_state=random_state)
    model.fit(X_train, y_train)

    # Error on the data the tree was fitted on.
    y_pred_train = model.predict(X_train)
    train_mse = metrics.mean_squared_error(y_train, y_pred_train)
    print("训练集MSE:", train_mse)  # "Training-set MSE:"

    # Error on the held-out test data; the raw predictions are printed too.
    y_pred_test = model.predict(X_test)
    print(y_pred_test)
    test_mse = metrics.mean_squared_error(y_test, y_pred_test)
    print("测试集MSE:", test_mse)  # "Test-set MSE:"
if __name__ == '__main__':
    # Run the demo only when this file is executed directly, not on import.
    demo()
|