【项目实战】基于 MobileNetV3 实现恶意文件静态检测(下)

举报
sidiot 发表于 2023/10/20 23:25:20 2023/10/20
【摘要】 在上篇博文中,博主介绍了关于 MobileNetV3 的网络结构以及主体代码实现;接下来,博主将介绍模型的训练,验证评估以及接口设计。

前言

上篇博文中,博主介绍了关于 MobileNetV3 的网络结构以及主体代码实现;接下来,博主将介绍模型的训练,验证评估以及接口设计

最终一个直观的页面展示如下: image.png

页面是用 ChatGPT 简单生成的,看着比较简陋,请不要在意,重点还是模型的实现!

模型训练

在完成模型结构设计之后,接下来就是对模型进行训练了,通常模型对于输入都是有一定要求的,因此在训练之前,需要对数据进行相关处理,以确保能够被模型接收。

这里的话,由于样本文件的大小不一,同时也为了能够高效的检测出样本中的恶意部分,所以将样本切割成一个个  的图像块,对于不够1024的部分,使用0进行填充,代码如下所示:

def pltexe(self, arr):
    arr_n = len(arr) // (1024 * 1024)
    arr_end_len = len(arr) % (1024 * 1024)
    re_arr = []
    siz = 1024

    for ite in range(arr_n):
        st = ite * 1024 * 1024
        pggg0 = np.array(arr[st:st+1024*1024])
        re_arr.append(pggg0.reshape(siz, siz) / 255)

    if arr_end_len != 0:
        arr_ = (1024 * 1024 - arr_end_len) * [0]
        pggg0 = np.array(arr[1024*1024*arr_n:] + arr_)
        re_arr.append(pggg0.reshape(siz, siz) / 255)

    return re_arr

优化器,损失函数等自己根据需要进行设置,这里仅作参考:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, betas=(
    0.9, 0.999), eps=1e-05, weight_decay=4e-05, amsgrad=True)
    
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[10, 50], gamma=0.1)
    
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1., 0.2])).to(device)

train_loader = DataLoader(train_data, batch_size=20, shuffle=True,
                          num_workers=20, collate_fn=PadSequence(maxlen=0))

接下来就是模型训练了,将图像块输入到模型,获得预测结果与实际标签进行比对计算 ,通过反向传播来调整模型参数:

for iter_count, batch_data in enumerate(train_loader):
    test_x = batch_data[0].to(torch.float32).to(device)
    out = model(test_x)
    label = batch_data[1].to(device)
    train_size += label.size(0)

    loss = criterion(out, label.long())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_loss += loss.item() * label.size(0)
    preds = out.argmax(dim=1).cpu().detach().numpy()
    labels = label.cpu().detach().numpy()
    correct_count = int((labels == preds).sum())
    running_corrects += correct_count

训练过程如下所示:

1677489229013_4E7E2BEE-3A66-4a5b-B524-92685EFB5589.png

验证评估

对于完成训练的模型,我们需要通过一定的指标来评估一下模型的好坏,这里博主采用的是混淆矩阵。

混淆矩阵的每一列代表了预测类别,每一列的总数表示预测为该类别的数据的数目;每一行代表了数据的真实归属类别,每一行的数据总数表示该类别的数据实例的数目。

这里用的是2*2的混淆矩阵,四个指标分别为 TP、FP、TN、FN,其表示的意义为:

  1. TP (True Positive) 能够检测到正例,即预测和实际都为 P;
  2. FP (False Positive) 错误的正例,误将负例检测为正例,即预测为 P,实际为 N;
  3. TN (True Negative) 能够检测到负例,即预测和实际都为 N;
  4. FN (False Negative) 错误的负例,误将正例检测为负例,即预测为 N,实际为 P;

在获得 TP、FP、TN、FN 的值后,就可以计算出精确率(Accuracy)、准确率(Precision)、召回率(Recall),其表示的意义与公式如下:

  1. 精确率:表示模型识别正确的样本个数占总样本数的比例。

  2. 准确率:表示在模型识别为正类的样本中,正确的样本个数占总样本数的比例。

  3. 召回率:表示模型识别正确的正类样本个数占总的正类样本个数的比例。

相关代码如下所示:

preds_sg = out.argmax(dim=1).cpu().detach().numpy()
label_sg = label.cpu().detach().numpy()

preds_np = preds_sg
label_np = label_sg.reshape(-1)
train_correct01 = int(((preds_np == zes) & (label_np == ons)).sum())
train_correct10 = int(((preds_np == ons) & (label_np == zes)).sum())
train_correct11 = int(((preds_np == ons) & (label_np == ons)).sum())
train_correct00 = int(((preds_np == zes) & (label_np == zes)).sum())
FN += train_correct01
FP += train_correct10
TP += train_correct11
TN += train_correct00

accuracy = (TP+TN) / (TP+TN+FP+FN)
precision = TP / (TP+FP)
recall = TP / (TP+FN)

评估日志如下所示:

image.png

从数据上来看,模型的训练过程还是很健康的,也可以画图进行一个直观的展示:

image.png

接口设计

现在我们需要将模型部署上线,这里就做一个简单的接口设计,假设我们的业务需求是用户上传一个文件,通过模型的判断,返回结果告诉用户是不是恶意文件。

这里只要将模型的验证阶段稍作修改即可,伪代码如下所示:

def verify(file):
    import mobilenetv3

    pad = PadSequence()
    model = mobilenetv3(mode='small')

    # 模型的加载
    ...

    featurelist = []
    try:
        re_arr = pad.pltexe(pad.get_mnemonic_list(file))    
        for pgg0 in re_arr:
            featurelist.append(torch.tensor(pgg0))
        featurelist_batch = torch.stack(featurelist, dim=0)
        featurelist_batch = torch.stack((featurelist_batch,)*3, axis=1)
        print("data processed.")

        test_x = featurelist_batch.to(torch.float32).to(device)
        out = model(test_x)
        pred = out.argmax(dim=1).cpu().detach().numpy()
        print(pred, out)
        print("verification ended.")
        return {'status': 'success', 'pred': int(pred), 'out': out[0].tolist()}

    except Exception as e:
        print(e)
        return {'status': 'fail'}

上述代码将用户传入的文件进行处理,然后输入到模型中,对于模型返回的预测结果进行格式化后再进行返回。

image.png

我们再设计一个简单的前端页面,这里用现在爆火的 ChatGPT 来完成,让其先设计一个有文件上传按钮的前端页面,然后再对这个页面进行美化。

image.png

将部分内容略作修改,简单的前端页面就做好了。

最后通过 Flask 框架设计一个接口就可以了:

@app.route("/verify", methods=['GET', 'POST'])
def getVerify():
    fileStorage = request.get_data()
    from model import verify
    res = verify(fileStorage)
    print('res:',res)
    return res

后记

本文到此就结束了,文章细致的讲解了恶意文件静态检测模型的训练,验证评估以及接口设计。

以上就是 【项目实战】基于 MobileNetV3 实现恶意文件静态检测(下) 的全部内容了,希望本篇博文对大家有所帮助!

💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注,创作不易,请多多支持;

👍 公众号:sidiot的技术驿站

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。