# -*- coding: utf-8 -*-
"""Hand-rolled gradient boosting for binary classification.

Labels are encoded as -1 / +1.  Each boosting round fits a depth-1 regression
tree to the pseudo-residuals ``y / (1 + 2**(y * F(x)))``, i.e. the negative
gradient of the base-2 logistic loss ``log2(1 + 2**(-y * F(x)))``.  When run
as a script, the result is compared against scikit-learn's
``GradientBoostingClassifier`` on the breast-cancer dataset.
"""
import numpy as np
from sklearn import tree
from sklearn.datasets import load_breast_cancer, load_wine
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.model_selection import train_test_split

try:
    # Removed in scikit-learn 1.2; keep the name available without breaking
    # the import of this module on newer versions.
    from sklearn.datasets import load_boston
except ImportError:
    load_boston = None

# Shrinkage applied to every tree's output, both when training and scoring.
LEARNING_RATE = 0.7
# Number of boosting rounds (one stump per round).
N_ROUNDS = 10


def read_data():
    """Load the breast-cancer dataset and split it 70/30 into train/test.

    Returns:
        ``(Xtrain, Xtest, Ytrain, Ytest)`` where every label equal to 0 has
        been replaced by -1.
    """
    dataset = load_breast_cancer()
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(
        dataset.data, dataset.target, test_size=0.3
    )
    # Map label 0 to -1 so that the sign of the score is the predicted class.
    Ytrain = np.where(Ytrain == 0, -1, Ytrain)
    Ytest = np.where(Ytest == 0, -1, Ytest)
    return Xtrain, Xtest, Ytrain, Ytest


def init(Ytrain):
    """Return the constant initial score F0 for every training sample.

    F0 is ``log2(#positive / #negative)``.  Starting from this prior rather
    than from zero is probably meant to make training converge a bit faster.

    Args:
        Ytrain: 1-D array of labels; entries equal to 1 count as positive,
            everything else as negative.

    Returns:
        Array with the same length as ``Ytrain`` filled with F0.  If one of
        the two classes is absent, F0 is +/-inf (NumPy emits a warning).
    """
    # np.sum keeps NumPy scalar semantics, so a missing class yields +/-inf
    # instead of raising ZeroDivisionError.
    positive = np.sum(Ytrain == 1)
    negative = Ytrain.shape[0] - positive
    p = np.log2(positive / negative)
    return np.ones(Ytrain.shape[0]) * p


def fit(Xtrain, Ytrain, n_rounds=N_ROUNDS, learning_rate=LEARNING_RATE):
    """Train the boosted ensemble of depth-1 regression trees.

    Args:
        Xtrain: 2-D feature matrix.
        Ytrain: 1-D array of -1 / +1 labels.
        n_rounds: number of boosting rounds.
        learning_rate: shrinkage applied to each tree's prediction.

    Returns:
        ``(trees, fx0)``: the list of fitted trees and the scalar initial
        score.  Pass the same ``learning_rate`` to ``score``.
    """
    print("init", Ytrain[:10])
    clf_tress = []
    fx0 = init(Ytrain)
    print("0", fx0[:10])

    gx = fx0  # running ensemble score F(x) on the training set
    for i in range(n_rounds):
        # Compute the pseudo-residuals (negative gradient of the loss).
        residuals = Ytrain / (np.exp2(Ytrain * gx) + 1)
        print("第", i, '轮 残差', residuals[:10])

        # The default split criterion is squared error on every scikit-learn
        # version; naming it explicitly ("mse") breaks on >= 1.2.
        clf = tree.DecisionTreeRegressor(max_features=1, max_depth=1)
        clf.fit(Xtrain, residuals)
        clf_tress.append(clf)

        fx_i = clf.predict(Xtrain) * learning_rate
        print("第", i, '轮 结果', fx_i[:10])
        gx = gx + fx_i

    print(gx[:10])
    p = np.mean(np.sign(gx) == Ytrain)
    print("准确率", p)
    return clf_tress, fx0[0]


def score(Xtest, Ytest, trees, fx0, learning_rate=LEARNING_RATE):
    """Print and return the accuracy of the boosted ensemble.

    Also prints the accuracy obtained from the first tree alone.

    Args:
        Xtest: 2-D feature matrix.
        Ytest: 1-D array of -1 / +1 labels.
        trees: fitted regressors as returned by ``fit``.
        fx0: scalar initial score as returned by ``fit``.
        learning_rate: shrinkage used during training; it must match the
            value given to ``fit`` so the same ensemble is evaluated.

    Returns:
        Accuracy of the full ensemble on ``(Xtest, Ytest)``.
    """
    gx = np.ones(Ytest.shape[0]) * fx0
    for stump in trees:
        # Apply the same shrinkage that was used while training.
        gx = gx + learning_rate * stump.predict(Xtest)
    p = np.mean(np.sign(gx) == Ytest)
    print("准确率", p)

    if trees:
        first_only = np.mean(np.sign(trees[0].predict(Xtest)) == Ytest)
        print("准确率0", first_only)
    return p


def main():
    """Train the hand-rolled model and compare it with scikit-learn's GBDT."""
    Xtrain, Xtest, Ytrain, Ytest = read_data()
    trees, fx0 = fit(Xtrain, Ytrain)
    score(Xtest, Ytest, trees, fx0)

    gbm1 = GradientBoostingClassifier(
        n_estimators=N_ROUNDS,
        max_depth=1,
        learning_rate=LEARNING_RATE,
        max_features='sqrt',
        random_state=10,
    )
    gbm1.fit(Xtrain, Ytrain)
    print("gbdt", gbm1.score(Xtest, Ytest))


if __name__ == '__main__':
    main()