【项目实战】基于 MobileNetV3 实现恶意文件静态检测(下)
前言
在上篇博文中,博主介绍了关于 MobileNetV3 的网络结构以及主体代码实现;接下来,博主将介绍模型的训练,验证评估以及接口设计。
最终一个直观的页面展示如下:
页面是用 ChatGPT 简单生成的,看着比较简陋,请不要在意,重点还是模型的实现!
模型训练
在完成模型结构设计之后,接下来就是对模型进行训练了,通常模型对于输入都是有一定要求的,因此在训练之前,需要对数据进行相关处理,以确保能够被模型接收。
这里的话,由于样本文件的大小不一,同时也为了能够高效的检测出样本中的恶意部分,所以将样本切割成一个个 的图像块,对于不够1024的部分,使用0进行填充,代码如下所示:
def pltexe(self, arr):
arr_n = len(arr) // (1024 * 1024)
arr_end_len = len(arr) % (1024 * 1024)
re_arr = []
siz = 1024
for ite in range(arr_n):
st = ite * 1024 * 1024
pggg0 = np.array(arr[st:st+1024*1024])
re_arr.append(pggg0.reshape(siz, siz) / 255)
if arr_end_len != 0:
arr_ = (1024 * 1024 - arr_end_len) * [0]
pggg0 = np.array(arr[1024*1024*arr_n:] + arr_)
re_arr.append(pggg0.reshape(siz, siz) / 255)
return re_arr
优化器,损失函数等自己根据需要进行设置,这里仅作参考:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, betas=(
0.9, 0.999), eps=1e-05, weight_decay=4e-05, amsgrad=True)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[10, 50], gamma=0.1)
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1., 0.2])).to(device)
train_loader = DataLoader(train_data, batch_size=20, shuffle=True,
num_workers=20, collate_fn=PadSequence(maxlen=0))
接下来就是模型训练了,将图像块输入到模型,获得预测结果与实际标签进行比对计算 ,通过反向传播来调整模型参数:
for iter_count, batch_data in enumerate(train_loader):
test_x = batch_data[0].to(torch.float32).to(device)
out = model(test_x)
label = batch_data[1].to(device)
train_size += label.size(0)
loss = criterion(out, label.long())
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() * label.size(0)
preds = out.argmax(dim=1).cpu().detach().numpy()
labels = label.cpu().detach().numpy()
correct_count = int((labels == preds).sum())
running_corrects += correct_count
训练过程如下所示:
验证评估
对于完成训练的模型,我们需要通过一定的指标来评估一下模型的好坏,这里博主采用的是混淆矩阵。
混淆矩阵的每一列代表了预测类别,每一列的总数表示预测为该类别的数据的数目;每一行代表了数据的真实归属类别,每一行的数据总数表示该类别的数据实例的数目。
这里用的是2*2的混淆矩阵,四个指标分别为 TP、FP、TN、FN,其表示的意义为:
- TP (True Positive) 能够检测到正例,即预测和实际都为 P;
- FP (False Positive) 错误的正例,误将负例检测为正例,即预测为 P,实际为 N;
- TN (True Negative) 能够检测到负例,即预测和实际都为 N;
- FN (False Negative) 错误的负例,误将正例检测为负例,即预测为 N,实际为 P;
在获得 TP、FP、TN、FN 的值后,就可以计算出精确率(Accuracy)、准确率(Precision)、召回率(Recall),其表示的意义与公式如下:
精确率:表示模型识别正确的样本个数占总样本数的比例。
准确率:表示在模型识别为正类的样本中,正确的样本个数占总样本数的比例。
召回率:表示模型识别正确的正类样本个数占总的正类样本个数的比例。
相关代码如下所示:
preds_sg = out.argmax(dim=1).cpu().detach().numpy()
label_sg = label.cpu().detach().numpy()
preds_np = preds_sg
label_np = label_sg.reshape(-1)
train_correct01 = int(((preds_np == zes) & (label_np == ons)).sum())
train_correct10 = int(((preds_np == ons) & (label_np == zes)).sum())
train_correct11 = int(((preds_np == ons) & (label_np == ons)).sum())
train_correct00 = int(((preds_np == zes) & (label_np == zes)).sum())
FN += train_correct01
FP += train_correct10
TP += train_correct11
TN += train_correct00
accuracy = (TP+TN) / (TP+TN+FP+FN)
precision = TP / (TP+FP)
recall = TP / (TP+FN)
评估日志如下所示:
从数据上来看,模型的训练过程还是很健康的,也可以画图进行一个直观的展示:
接口设计
现在我们需要将模型部署上线,这里就做一个简单的接口设计,假设我们的业务需求是用户上传一个文件,通过模型的判断,返回结果告诉用户是不是恶意文件。
这里只要将模型的验证阶段稍作修改即可,伪代码如下所示:
def verify(file):
import mobilenetv3
pad = PadSequence()
model = mobilenetv3(mode='small')
# 模型的加载
...
featurelist = []
try:
re_arr = pad.pltexe(pad.get_mnemonic_list(file))
for pgg0 in re_arr:
featurelist.append(torch.tensor(pgg0))
featurelist_batch = torch.stack(featurelist, dim=0)
featurelist_batch = torch.stack((featurelist_batch,)*3, axis=1)
print("data processed.")
test_x = featurelist_batch.to(torch.float32).to(device)
out = model(test_x)
pred = out.argmax(dim=1).cpu().detach().numpy()
print(pred, out)
print("verification ended.")
return {'status': 'success', 'pred': int(pred), 'out': out[0].tolist()}
except Exception as e:
print(e)
return {'status': 'fail'}
上述代码将用户传入的文件进行处理,然后输入到模型中,对于模型返回的预测结果进行格式化后再进行返回。
我们再设计一个简单的前端页面,这里用现在爆火的 ChatGPT 来完成,让其先设计一个有文件上传按钮的前端页面,然后再对这个页面进行美化。
将部分内容略作修改,简单的前端页面就做好了。
最后通过 Flask 框架设计一个接口就可以了:
@app.route("/verify", methods=['GET', 'POST'])
def getVerify():
fileStorage = request.get_data()
from model import verify
res = verify(fileStorage)
print('res:',res)
return res
后记
本文到此就结束了,文章细致的讲解了恶意文件静态检测模型的训练,验证评估以及接口设计。
以上就是 【项目实战】基于 MobileNetV3 实现恶意文件静态检测(下) 的全部内容了,希望本篇博文对大家有所帮助!
💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注,创作不易,请多多支持;
👍 公众号:sidiot的技术驿站;
- 点赞
- 收藏
- 关注作者
评论(0)