#!/usr/bin/python
# -*- coding: UTF-8 -*-
'''
Minimal MSE demo: fit a DecisionTreeRegressor on the configured training
set and report the mean squared error (MSE) on the training and test sets.
'''
import sys
import os

# Make the project root (the parent of this script's directory) importable
# so ``util.config`` resolves. Anchored to this file's location instead of
# the current working directory, so the script also works when launched
# from a different directory.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
)

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn import metrics
from util.config import config


def read_data(path):
    """Load a dataset in which every non-blank line is a ``(features, target)`` pair.

    Each line is evaluated as a Python expression, and the resulting pairs are
    split into a feature array and a target array.

    Args:
        path: Path to the data file (read as UTF-8 text).

    Returns:
        A tuple ``(X, y)`` of numpy arrays: ``X`` built from the first element
        of every pair, ``y`` from the second.

    Raises:
        ValueError: If the file contains no non-blank lines, or a line does not
            evaluate to a two-element pair.
    """
    records = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline at EOF); eval('')
                # would raise SyntaxError.
                continue
            # NOTE(review): eval() executes arbitrary code from the data file.
            # If this file can come from an untrusted source, switch to
            # ast.literal_eval (works only if each line is a pure literal).
            records.append(eval(line))
    if not records:
        # zip(*[]) would otherwise fail with an unhelpful unpacking error.
        raise ValueError('no data lines found in %s' % path)
    X, y = zip(*records)
    return np.array(X), np.array(y)


def demo(max_depth=10):
    """Train a decision-tree regressor and print train/test MSE.

    The train and test file paths are obtained via
    ``config.get('application', 'train_data_path')`` and
    ``config.get('application', 'test_data_path')``.

    Args:
        max_depth: Maximum depth of the tree. Defaults to 10, the value
            previously hard-coded here.
    """
    X_train, y_train = read_data(config.get('application', 'train_data_path'))
    X_test, y_test = read_data(config.get('application', 'test_data_path'))

    model = DecisionTreeRegressor(max_depth=max_depth)
    model.fit(X_train, y_train)

    # Training-set MSE.
    y_pred_train = model.predict(X_train)
    train_mse = metrics.mean_squared_error(y_train, y_pred_train)
    print("训练集MSE:", train_mse)

    # Test-set predictions followed by test-set MSE.
    y_pred_test = model.predict(X_test)
    print(y_pred_test)
    test_mse = metrics.mean_squared_error(y_test, y_pred_test)
    print("测试集MSE:", test_mse)


if __name__ == '__main__':
    demo()