您好,欢迎来到九壹网。
搜索
您的当前位置:首页深度学习模型融合stacking

深度学习模型融合stacking

来源:九壹网
深度学习模型融合stacking

当你的深度学习模型变得很多时,选一个确定的模型也是一个头痛的问题。或者你可以把它们都用起来,也就是进行模型融合。我主要使用 stacking 和 blending 方法。先把代码贴出来,大家可以看一下。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC, LinearSVC

# One seed shared by the data split and every model, for reproducibility.
SEED = 222
np.random.seed(SEED)

# Raw dataset, loaded once at import time; get_train_test() reads it.
df = pd.read_csv('input.csv')

def get_train_test(data=None, test_size=0.95):
    """Build one-hot features and a binary target, then split train/test.

    Args:
        data: DataFrame containing a ``cand_pty_affiliation`` column plus the
            feature columns. Defaults to the module-level ``df``.
        test_size: fraction of rows held out for testing. The default of 0.95
            keeps only 5% of the rows for training (presumably to keep the
            demo fast — confirm before using for real evaluation).

    Returns:
        ``xtrain, xtest, ytrain, ytest`` exactly as ``train_test_split`` returns
        them, split with the global ``SEED``.
    """
    if data is None:
        data = df
    # Target: 1 where the party affiliation is "REP", 0 otherwise.
    y = 1 * (data["cand_pty_affiliation"] == "REP")
    x = pd.get_dummies(data.drop(["cand_pty_affiliation"], axis=1), sparse=True)
    # Columns with zero standard deviation are constant and carry no signal.
    x = x.drop(x.columns[x.std() == 0], axis=1)
    return train_test_split(x, y, test_size=test_size, random_state=SEED)

def get_models():
    """Create a fresh, unfitted set of base learners.

    Returns:
        dict mapping a display name to a scikit-learn classifier. Iteration
        order is svm, knn, naive bayes, mlp-nn, random forest, gbm, logistic;
        the other helpers use that order for prediction-matrix columns.
    """
    return {
        'svm': SVC(C=100, probability=True),  # probability=True enables predict_proba
        'knn': KNeighborsClassifier(n_neighbors=3),
        'naive bayes': GaussianNB(),
        'mlp-nn': MLPClassifier((80, 10), early_stopping=False, random_state=SEED),
        # NOTE(review): n_estimators=1 makes this "forest" a single depth-3
        # tree — confirm this is intended rather than a leftover speed hack.
        'random forest': RandomForestClassifier(
            n_estimators=1, max_depth=3, random_state=SEED
        ),
        'gbm': GradientBoostingClassifier(n_estimators=100, random_state=SEED),
        'logistic': LogisticRegression(C=100, random_state=SEED),
    }

def train_base_learnres(base_learners, inp, out, verbose=True):
    """Fit every base learner, in place, on ``(inp, out)``.

    The misspelled name is kept because other code calls it.

    Args:
        base_learners: mapping of name -> estimator exposing ``fit(X, y)``.
            The estimators themselves are fitted (mutated).
        inp: training features.
        out: training targets.
        verbose: print one progress line per model when true.
    """
    if verbose:
        print("fitting models.")
    for name, model in base_learners.items():
        if verbose:
            # Flush so the marker is visible *before* the possibly slow fit.
            print("%s..." % name, end=" ", flush=True)
        model.fit(inp, out)
        if verbose:
            print("done")

def predict_base_learners(pred_base_learners, inp, verbose=True):
    """Collect each base learner's column-1 probability into a matrix.

    This matrix is the feature set the meta learner is trained on or
    predicts from.

    Args:
        pred_base_learners: mapping of name -> fitted estimator exposing
            ``predict_proba``.
        inp: samples to score; ``inp.shape[0]`` is used as the row count.
        verbose: print one progress line per model when true.

    Returns:
        float ndarray of shape ``(inp.shape[0], len(pred_base_learners))``.
        Column ``j`` is column 1 of the ``j``-th learner's ``predict_proba``
        output, in dict iteration order.
    """
    preds = np.zeros((inp.shape[0], len(pred_base_learners)))
    if verbose:
        print("Generating base learner predictions.")
    for col, (name, model) in enumerate(pred_base_learners.items()):
        if verbose:
            # Flush so the marker is visible *before* the prediction runs.
            print("%s..." % name, end=" ", flush=True)
        # Keep only the second class's probability; in a binary problem the
        # other column is redundant.
        preds[:, col] = model.predict_proba(inp)[:, 1]
        if verbose:
            print("done")
    return preds

def ensemble_predict(base_learners, meta_learner, inp, verbose=True):
    """Score new samples with a fitted two-layer ensemble.

    The samples must first pass through the base learners to produce the
    meta-features; the meta learner then scores those.

    Returns:
        ``(meta_features, proba)``: the base-learner probability matrix and
        column 1 of the meta learner's ``predict_proba`` output.
    """
    meta_features = predict_base_learners(base_learners, inp, verbose=verbose)
    proba = meta_learner.predict_proba(meta_features)
    return meta_features, proba[:, 1]

def ensenmble_by_blend():
    """Blend the base learners and print the ensemble's test ROC-AUC.

    The base learners are fitted on one half of the training data. The meta
    learner is fitted on their predictions for the other half. The name
    (misspelled) is kept for existing callers.

    Relies on module-level globals ``xtrain``, ``ytrain``, ``xtest``,
    ``ytest``, ``base_learners`` and ``meta_learner`` (set in the
    ``__main__`` block).
    """
    # Split the training data into two disjoint halves.
    xtrain_base, xpred_base, ytrain_base, ypred_base = train_test_split(
        xtrain, ytrain, test_size=0.5, random_state=SEED
    )
    # Layer 1: fit the base learners on the first half.
    train_base_learnres(base_learners, xtrain_base, ytrain_base)
    # Their predictions on the held-out half are the meta learner's features.
    p_base = predict_base_learners(base_learners, xpred_base)
    meta_learner.fit(p_base, ypred_base)
    # Score on the test set; the intermediate meta-features are not needed here.
    _, p = ensemble_predict(base_learners, meta_learner, xtest)
    print("\nEnsemble ROC-AUC score: %.3f" % roc_auc_score(ytest, p))

from sklearn.base import clone


def _take_rows(data, idx):
    """Select rows ``idx`` by position from an ndarray, sparse matrix or pandas object."""
    if hasattr(data, "iloc"):
        return data.iloc[idx]
    return data[idx]


def stacking(base_learners, meta_learner, X, y, generator):
    """Train a stacked ensemble using cross-validated meta-features.

    Steps:
      1. ``base_learners`` are fitted on all of ``(X, y)``; these are the
         learners used at prediction time.
      2. For each fold from ``generator.split(X)``, fresh clones are fitted on
         the training part and predict the held-out part. The out-of-fold
         predictions become the meta learner's training features.
      3. ``meta_learner`` is fitted on the stacked out-of-fold predictions.

    Args:
        base_learners: mapping of name -> estimator; fitted in place.
        meta_learner: estimator fitted in place on the meta-features.
        X, y: training data. Rows are selected positionally, so ndarrays,
            sparse matrices and pandas objects are all accepted.
        generator: CV splitter exposing ``split(X)`` (e.g. ``KFold``).

    Returns:
        ``(base_learners, meta_learner)``, both fitted.
    """
    print("Fitting final base learners...", end="", flush=True)
    train_base_learnres(base_learners, X, y, verbose=False)
    print("done")

    print("Generating cross-validated predictions...")
    cv_preds, cv_y = [], []
    for fold, (train_idx, test_idx) in enumerate(generator.split(X), start=1):
        fold_xtrain, fold_ytrain = _take_rows(X, train_idx), _take_rows(y, train_idx)
        fold_xtest, fold_ytest = _take_rows(X, test_idx), _take_rows(y, test_idx)

        # Fresh unfitted copies so the final learners above are untouched.
        fold_base_learners = {name: clone(model) for name, model in base_learners.items()}
        train_base_learnres(fold_base_learners, fold_xtrain, fold_ytrain, verbose=False)
        cv_preds.append(
            predict_base_learners(fold_base_learners, fold_xtest, verbose=False)
        )
        cv_y.append(fold_ytest)
        print("Fold %i done" % fold)
    print("CV-predictions done")

    # Predictions and targets are appended fold by fold in the same order,
    # so the rows of cv_preds and cv_y stay aligned.
    cv_preds = np.vstack(cv_preds)
    cv_y = np.hstack(cv_y)

    print("Fitting meta learner...", end="", flush=True)
    meta_learner.fit(cv_preds, cv_y)
    print("done")

    return base_learners, meta_learner

def ensemble_by_stack():
    """Stack with 2-fold CV and print the ensemble's test ROC-AUC.

    Relies on module-level globals ``xtrain``, ``ytrain``, ``xtest``,
    ``ytest`` and ``meta_learner`` (set in the ``__main__`` block).
    A fresh ``get_models()`` set and a clone of ``meta_learner`` are used, so
    those globals are not refitted.
    """
    from sklearn.model_selection import KFold

    cv_base_learners, cv_meta_learner = stacking(
        get_models(), clone(meta_learner), xtrain.values, ytrain.values, KFold(2)
    )
    # The learners were fitted on plain arrays (``.values``). Score the test
    # set the same way, so scikit-learn does not warn that X has feature names
    # the estimators were not fitted with.
    _, p = ensemble_predict(
        cv_base_learners, cv_meta_learner, xtest.values, verbose=False
    )
    print("\nEnsemble ROC-AUC score: %.3f" % roc_auc_score(ytest, p))

def plot_roc_curve(ytest, p_base_learners, p_ensemble, labels, ens_label):
    """Draw ROC curves for every base learner and the ensemble on one figure.

    Args:
        ytest: true binary labels.
        p_base_learners: 2-D array with one column of scores per base learner
            (e.g. the matrix returned by ``predict_base_learners``).
        p_ensemble: ensemble scores for the same samples.
        labels: legend labels, indexed by base-learner column.
        ens_label: legend label for the ensemble curve.
    """
    n_learners = p_base_learners.shape[1]
    # One rainbow colour per curve: index 0 for the ensemble, 1..n for learners.
    colours = [plt.cm.rainbow(t) for t in np.linspace(0, 1.0, n_learners + 1)]

    plt.figure(figsize=(10, 8))
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal

    for col, colour in enumerate(colours[1:]):
        fpr, tpr, _ = roc_curve(ytest, p_base_learners[:, col])
        plt.plot(fpr, tpr, label=labels[col], c=colour)

    ens_fpr, ens_tpr, _ = roc_curve(ytest, p_ensemble)
    plt.plot(ens_fpr, ens_tpr, label=ens_label, c=colours[0])

    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve')
    plt.legend(frameon=False)
    plt.show()

from mlens.ensemble import SuperLearner


def use_pack():
    """Build the ensemble with mlens' SuperLearner and print its test ROC-AUC.

    The base learners form the first layer (10 folds, probability outputs),
    with ``meta_learner`` on top. Relies on module-level globals
    ``base_learners``, ``meta_learner``, ``xtrain``, ``ytrain``, ``xtest``
    and ``ytest`` (set in the ``__main__`` block).
    """
    # An alternative parallel backend can be chosen with
    # backend="multiprocessing" in the constructor below.
    super_learner = SuperLearner(folds=10, random_state=SEED, verbose=2)

    super_learner.add(list(base_learners.values()), proba=True)  # base layer
    super_learner.add_meta(meta_learner, proba=True)             # meta layer
    super_learner.fit(xtrain, ytrain)

    test_proba = super_learner.predict_proba(xtest)
    print("\nSuper Learner ROC-AUC score: %.3f" % roc_auc_score(ytest, test_proba[:, 1]))

if __name__ == "__main__":
    # The ensemble functions read these as module-level globals.
    xtrain, xtest, ytrain, ytest = get_train_test()
    base_learners = get_models()

    # Meta learner: 1000 boosting rounds with a very small learning rate,
    # depth-4 trees and 50% row subsampling.
    _meta_params = dict(
        n_estimators=1000,
        loss="exponential",
        max_depth=4,
        subsample=0.5,
        learning_rate=0.005,
        random_state=SEED,
    )
    meta_learner = GradientBoostingClassifier(**_meta_params)

    # Pick one fusion strategy:
    # ensenmble_by_blend()  # blending
    # ensemble_by_stack()   # stacking
    use_pack()  # fusion via the mlens SuperLearner package

因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- 91gzw.com 版权所有 湘ICP备2023023988号-2

违法及侵权请联系:TEL:199 18 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务