Stacking集成学习挑战天池新人赛【工业蒸汽量预测 】 (2) 基础类、交叉验证方法构建

举报
地上一只鹅~ 发表于 2018/12/23 15:06:49 2018/12/23
【摘要】 由于后续将使用sklearn库实现大部分的初级学习模型,这里将构建一个sklern基础类,方便代码的使用和拓展。基础类构建class SklearnHelper(object): def __init__(self, clf, seed=0, params=None): params['random_state'] = seed self.clf = clf...

由于后续将使用sklearn库实现大部分的初级学习模型,这里将构建一个sklern基础类,方便代码的使用和拓展。

基础类构建

# 预测结果以mean square error作为评判标准(均方差越小越好)
from sklearn.metrics import mean_squared_error

class SklearnHelper(object):
    def __init__(self, clf, seed=0, params=None):
        params['random_state'] = seed
        self.clf = clf(**params)

    def train(self, x_train, y_train, x_val, y_val):
        self.clf.fit(x_train, y_train)
        y_pre = self.predict(x_val)
        return mean_squared_error(y_val, y_pre)

    def fit(self, x_train, y_train):
        return self.clf.fit(x_train, y_train)
    
    def predict(self, x):
        return self.clf.predict(x)
    
    def feature_importances(self):
        print(self.clf.feature_importances_)

交叉验证方法构建

from sklearn.model_selection import KFold

def get_oof(clf, x_train, y_train, x_test, n_folds = 5):
    """K-fold stacking"""
    num_train, num_test = x_train.shape[0], x_test.shape[0]
    oof_train = np.zeros((num_train,)) 
    oof_test = np.zeros((num_test,))
    oof_test_all_fold = np.zeros((num_test, n_folds))
    scores = []
    KF = KFold(n_splits = n_folds, random_state=2017)
    for i, (train_index, val_index) in enumerate(KF.split(x_train)):
        print('{0} fold, train {1}, val {2}'.format(i, len(train_index), len(val_index)))
        x_tra, y_tra = x_train[train_index], y_train[train_index]
        x_val, y_val = x_train[val_index], y_train[val_index]
        score = clf.train(x_tra, y_tra, x_val, y_val)
        scores.append(score)
        oof_train[val_index] = clf.predict(x_val)
        oof_test_all_fold[:, i] = clf.predict(x_test)
    oof_test = np.mean(oof_test_all_fold, axis=1)
    print('all scores {0}, average {1}'.format(scores, np.mean(scores)))
    return oof_train, oof_test

下一篇介绍初级学习模型构建

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。