#!/usr/bin/python
# -*- coding: UTF-8 -*-
'''
Simplest MSE demo: fit a decision-tree regressor and report train/test MSE.
'''
import sys
import os

sys.path.append(os.path.abspath('..'))

import numpy as np


def read_data(path):
    """Load ``(features, target)`` records from a text file, one per line.

    Each non-blank line is evaluated as a Python expression that must yield a
    2-item sequence ``(x, y)``. The feature parts are stacked into ``X`` and
    the targets into ``y``.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(X, y)`` of NumPy arrays built from the features and targets.

    Raises:
        ValueError: If the file contains no (non-blank) records.
    """
    with open(path, encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline at EOF) are skipped because
        # eval('') raises SyntaxError.
        # NOTE(review): eval() executes arbitrary code taken from the data
        # file. Only pass trusted files; if the records are plain Python
        # literals, ast.literal_eval would be a safer replacement.
        records = [eval(line) for line in (raw.strip() for raw in f) if line]
    if not records:
        raise ValueError(f"no records found in {path!r}")
    X, y = zip(*records)
    return np.array(X), np.array(y)


def demo(train_path="../bbztx/train_data", test_path="../bbztx/test_data",
         max_depth=10):
    """Train a DecisionTreeRegressor and print its train and test MSE.

    The test-set predictions are printed as well. Default paths are relative
    and therefore resolved against the current working directory.

    Args:
        train_path: Training data file (format as described in ``read_data``).
        test_path: Test data file (same format).
        max_depth: Maximum depth of the regression tree.
    """
    # Imported lazily so read_data stays usable without scikit-learn.
    from sklearn import metrics
    from sklearn.tree import DecisionTreeRegressor

    X_train, y_train = read_data(train_path)
    X_test, y_test = read_data(test_path)

    dt1 = DecisionTreeRegressor(max_depth=max_depth)
    dt1.fit(X_train, y_train)

    y_pred_train = dt1.predict(X_train)
    train_mse = metrics.mean_squared_error(y_train, y_pred_train)
    print("训练集MSE:", train_mse)

    y_pred_test = dt1.predict(X_test)
    print(y_pred_test)
    test_mse = metrics.mean_squared_error(y_test, y_pred_test)
    print("测试集MSE:", test_mse)


if __name__ == '__main__':
    demo()