random_forest.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # -*- encoding:utf-8 -*-
  2. from sklearn.tree import DecisionTreeClassifier
  3. from sklearn.ensemble import RandomForestClassifier
  4. from sklearn.ensemble import RandomForestRegressor
  5. from sklearn.datasets import load_wine
  6. from sklearn.model_selection import train_test_split
  7. import numpy as np
  8. from sklearn.tree import DecisionTreeRegressor
  9. from sklearn import metrics
  10. '''
  11. 参数 含义
  12. criterion 不纯度的衡量指标,有基尼系数和信息熵两种选择
  13. max_depth 树的最大深度,超过最大深度的树枝都会被剪掉
  14. min_samples_leaf 一个节点在分枝后的每个子节点都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生
  15. min_samples_split 一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生
  16. max_features 限制每次分枝时(随机抽取)考虑的特征个数,超过限制个数的特征在该次分枝中不会被考虑;默认值因模型而异:RandomForestClassifier 为总特征个数开平方取整("sqrt"),RandomForestRegressor 和决策树则使用全部特征
  17. min_impurity_decrease 限制信息增益的大小,信息增益小于设定数值的分枝不会发生
  18. '''
  19. def demo_wine():
  20. wine = load_wine()
  21. #print wine.data
  22. #print wine.target
  23. Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
  24. rfc = RandomForestClassifier(random_state=0,n_estimators=10)
  25. rfc = rfc.fit(Xtrain,Ytrain)
  26. print(rfc.score(Xtest,Ytest))
  27. def read_data(path):
  28. with open(path) as f :
  29. lines=f.readlines()
  30. lines=[eval(line.strip()) for line in lines]
  31. X,y=zip(*lines)
  32. X=np.array(X)
  33. y=np.array(y)
  34. return X,y
  35. def demo():
  36. X_train, y_train = read_data("../bbztx/train_data")
  37. X_test, y_test = read_data("../bbztx/test_data")
  38. Xtrain, Xtest, Ytrain, Ytest = train_test_split(X_train, y_train, test_size=0.3)
  39. rfc = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=10)
  40. rfc = rfc.fit(Xtrain, Ytrain)
  41. print(rfc.score(Xtest, Ytest))
  42. print(rfc.predict(X_test))
  43. # print(rfc.score(Xtest, y_test))
  44. y_pred_test = rfc.predict(X_test)
  45. test_mse = metrics.mean_squared_error(y_test, y_pred_test)
  46. print("测试集MSE:", test_mse)
  47. if __name__ == '__main__':
  48. demo()