# random_forest.py
  1. # -*- encoding:utf-8 -*-
  2. from sklearn.tree import DecisionTreeClassifier
  3. from sklearn.ensemble import RandomForestClassifier
  4. from sklearn.ensemble import RandomForestRegressor
  5. from sklearn.datasets import load_wine
  6. from sklearn.model_selection import train_test_split
  7. import numpy as np
  8. from util.config import config
  9. from sklearn.tree import DecisionTreeRegressor
  10. from sklearn import metrics
'''
决策树 / 随机森林常用参数说明
参数                   含义
criterion              不纯度的衡量指标。分类树可选基尼系数(gini)和信息熵(entropy);回归树使用均方误差(squared_error)等指标
max_depth              树的最大深度,超过最大深度的树枝都会被剪掉
min_samples_leaf       一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生
min_samples_split      一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生
max_features           限制每次分枝时考虑的特征个数:每次分枝只在随机抽取的max_features个特征中寻找最佳切分,其余特征在该次分枝中不被考虑。RandomForestClassifier的默认值为总特征个数开平方("sqrt");RandomForestRegressor和单棵决策树默认考虑全部特征
min_impurity_decrease  限制不纯度下降(信息增益)的大小,不纯度下降量小于设定数值的分枝不会发生
'''
  20. def demo_wine():
  21. wine = load_wine()
  22. #print wine.data
  23. #print wine.target
  24. Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
  25. rfc = RandomForestClassifier(random_state=0,n_estimators=10)
  26. rfc = rfc.fit(Xtrain,Ytrain)
  27. print(rfc.score(Xtest,Ytest))
  28. def read_data(path):
  29. with open(path) as f :
  30. lines=f.readlines()
  31. lines=[eval(line.strip()) for line in lines]
  32. X,y=zip(*lines)
  33. X=np.array(X)
  34. y=np.array(y)
  35. return X,y
  36. def demo():
  37. X_train, y_train = read_data(config.get('application', 'train_data_path'))
  38. X_test, y_test = read_data(config.get('application', 'train_data_path'))
  39. Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
  40. rfc = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=10)
  41. rfc = rfc.fit(Xtrain, Ytrain)
  42. print(rfc.score(Xtest, Ytest))
  43. print(rfc.predict(X_test))
  44. # print(rfc.score(Xtest, y_test))
  45. y_pred_test = rfc.predict(X_test)
  46. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  47. print("测试集MSE:", test_mse)
  48. if __name__ == '__main__':
  49. demo()