Generating a Custom Recommendation Dataset

逸尘2022, published 2022/03/10 15:49:28

The ModelArts AI Gallery recommendation preset algorithm (https://developer.huaweicloud.com/develop/aigallery/algorithm/detail?id=214dcb6c-9d58-40e2-b7f6-9091d22c8d36) ships with a subset of the Criteo dataset and a subset of the Ali-CCP dataset.

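This tutorial shows how to generate your own recommendation dataset. Each row contains a label, continuous features, categorical features, and multi-valued categorical features.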
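The script below writes each sample as one flat CSV row with no header: the label first, then the continuous features, the categorical features, and finally the multi-valued categorical features. As a minimal sketch of that layout, the hypothetical helper `column_names` builds a matching list of column names; the names I1, C1 and M1 are illustrative only (I and C follow the usual Criteo naming) and are not written to the CSV by the script.

```python
from typing import List


def column_names(value_num: int, category_num: int, multi_category_num: int) -> List[str]:
    """Hypothetical header for one generated row: label, continuous, categorical, multi-valued."""
    # Order mirrors `label + val_feature + cat_feature + multi_cat_feature` in CTRData.generate
    return (["label"]
            + [f"I{i}" for i in range(1, value_num + 1)]
            + [f"C{i}" for i in range(1, category_num + 1)]
            + [f"M{i}" for i in range(1, multi_category_num + 1)])


# Same counts as the script's __main__ block
names = column_names(value_num=16, category_num=4, multi_category_num=3)
print(len(names))                                # 24
print(names[0], names[1], names[17], names[21])  # label I1 C1 M1
```

With the script's call `generate(item_id)`, the item ID is appended after the random categorical features, so it lands in the last categorical column (C4 with these counts).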


# Copyright 2022 ModelArts Authors from Huawei Cloud. All Rights Reserved.
# https://www.huaweicloud.com/product/modelarts.html
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import csv
import secrets

secrets_generator = secrets.SystemRandom()


class CTRData(object):
    def __init__(self, value_num=13, category_num=26, multi_category_num=0, max_sequence_len=3):
        """
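        :param value_num: number of continuous (numeric) features
        :param category_num: number of categorical features
        :param multi_category_num: number of multi-valued categorical features
        :param max_sequence_len: maximum number of values in one multi-valued categorical feature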
        """
        self.value_num = value_num
        self.cat_num = category_num
        self.multi_cat_num = multi_category_num
        self.max_seq_len = max_sequence_len

    def generate(self, item_id=None, user_id=None):
        """
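        Randomly generate one row of data: a label, value_num continuous features,
        category_num categorical features and multi_category_num multi-valued categorical features.
        :param item_id: item ID; if given, it fills one of the category_num categorical columns
        :param user_id: user ID; if given, it fills one of the category_num categorical columns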
        """
        category_col_num = self.cat_num

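        # When item_id and/or user_id are given, they occupy some of the category_num categorical columns.
        # Compare with None so that an ID of 0 is still treated as a real ID.
        if item_id is not None:
            category_col_num -= 1
        if user_id is not None:
            category_col_num -= 1

        # Label: 0 or 1
        label = [str(secrets_generator.randint(0, 1))]
        # Continuous features in [0, 1). randint(0, 10) returns 0 with probability 1/11,
        # in which case the cell is left empty to simulate a missing value.
        val_feature = [str(secrets_generator.random()) if secrets_generator.randint(0, 10) else ""
                       for _ in range(self.value_num)]
        # Categorical features are IDs in [1, 100], with the same ~1/11 chance of being missing.
        cat_feature = [str(secrets_generator.randint(1, 100)) if secrets_generator.randint(0, 10) else ""
                       for _ in range(category_col_num)]

        # Each multi-valued feature holds 0 to max_seq_len IDs in [1, 100], joined with "|";
        # a length of 0 yields an empty cell.
        multi_category_len = [secrets_generator.randint(0, self.max_seq_len) for _ in range(self.multi_cat_num)]
        multi_cat_feature_list = [[secrets_generator.randint(1, 100) for _ in range(length)]
                                  for length in multi_category_len]
        multi_cat_feature = ["|".join(map(str, x)) for x in multi_cat_feature_list]

        # The item ID and user ID are appended after the random categorical features.
        if item_id is not None:
            cat_feature += [str(item_id)]
        if user_id is not None:
            cat_feature += [str(user_id)]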

        line = label + val_feature + cat_feature + multi_cat_feature
        return line


def mkdir(dir_path):
    dir_path = dir_path.strip()
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


if __name__ == "__main__":
    #############################################################################################################
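    # IMPORTANT:
    # Set value_num (number of continuous features), category_num (number of categorical features),
    # multi_category_num (number of multi-valued categorical features), max_sequence_len (maximum length
    # of a multi-valued categorical feature), the total number of items, and the numbers of training and
    # test samples to the values that match your own data.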
    #############################################################################################################
    ctr_data = CTRData(value_num=16, category_num=4, multi_category_num=3, max_sequence_len=5)
    # Total number of items; item IDs range from 0 to item_num - 1
    item_num = 8 * (10 ** 2)
    # Number of training samples
    train_num = 80 * (10 ** 2)
    # Number of test samples
    test_num = 1000

    train_data_path = '../../ctr_data/train_data/train.csv'
    test_data_path = '../../ctr_data/test_data/test.csv'
    mkdir(os.path.dirname(train_data_path))
    mkdir(os.path.dirname(test_data_path))

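    # newline='' is what the csv module expects; it avoids blank lines between rows on Windows
    with open(train_data_path, 'w', newline='') as f1:
        f = csv.writer(f1)
        for i in range(train_num):
            # Assign item IDs round-robin over 0 .. item_num - 1
            line_data = ctr_data.generate(i % item_num)
            f.writerow(line_data)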
    print("train data is done!")

    with open(test_data_path, 'w', newline='') as f2:
        f = csv.writer(f2)
        for i in range(test_num):
            # Test rows use item IDs drawn uniformly from 0 .. item_num - 1
            test_item_id = secrets_generator.randint(0, item_num - 1)
            line_data = ctr_data.generate(test_item_id)
            f.writerow(line_data)
    print("test data is done!")
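To consume the generated files, each row has to be split back into its parts. The following is a minimal sketch, assuming the column order used by `CTRData.generate` and the "|" separator; the function name `parse_row`, filling missing cells with 0, and padding multi-valued features with 0 up to max_sequence_len are all assumptions for illustration, not how the ModelArts preset algorithm actually reads the data.

```python
import csv
import io
from typing import List, Tuple


def parse_row(row: List[str], value_num: int, category_num: int,
              multi_category_num: int, max_sequence_len: int
              ) -> Tuple[int, List[float], List[int], List[List[int]]]:
    """Hypothetical parser: split one CSV row into label, dense, sparse and padded multi-valued parts."""
    label = int(row[0])
    # Empty cells are missing values; fill them with 0 (an assumption, not part of the original script)
    dense = [float(v) if v else 0.0 for v in row[1:1 + value_num]]
    start = 1 + value_num
    sparse = [int(v) if v else 0 for v in row[start:start + category_num]]
    start += category_num
    multi = []
    for cell in row[start:start + multi_category_num]:
        ids = [int(v) for v in cell.split("|")] if cell else []
        ids = ids[:max_sequence_len]  # guard against over-long sequences
        # Pad with 0 to a fixed length so every row has the same shape
        multi.append(ids + [0] * (max_sequence_len - len(ids)))
    return label, dense, sparse, multi


# One hand-written row: label, 2 continuous, 2 categorical, 2 multi-valued features
sample = io.StringIO("1,0.5,,7,42,3|9,\n")
row = next(csv.reader(sample))
print(parse_row(row, value_num=2, category_num=2, multi_category_num=2, max_sequence_len=3))
# (1, [0.5, 0.0], [7, 42], [[3, 9, 0], [0, 0, 0]])
```

Using 0 as a filler does not clash with the random categorical IDs, which lie in [1, 100], but item IDs in the script start at 0, so a real pipeline may want a different sentinel for the item column.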

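[Copyright notice] This article is original content contributed by a Huawei Cloud Community user. When reposting it, you must credit the source (Huawei Cloud Community), the article link, the author, and other basic information; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the community will promptly remove the infringing content. Report email: cloudbbs@huaweicloud.com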