用Python做图像分类
时至今日,简单的图像识别在部分场景已经有比较成熟的应用,今天我们分享一种基于pytorch的比较成熟的算法。其实算法的应用非常简单,主要是调试部署的过程。
-
调用torchvision下的训练好的resnet
from torchvision import models
显示找不到torchvision,估计是包还没有安装过。
!pip install torchvision
提示已经安装,但pip的版本偏低,尽管不是主要原因,我们还是顺带升级一下pip。
!python3 -m pip install --upgrade pip
升级成功,确认torchvision是否已经安装。
!pip show torchvision
显示我们确实已经安装了torchvision,再次尝试调用该包。
import torchvision
仍然显示不存在。
查看当前路径
import syssys.path
对比上面torchvision的安装路径可以发现,torchvision的安装路径并未包括在内,手动把该路径包含进去。
sys.path.append('/root/anaconda3/lib/python3.7/site-packages')
在次尝试调用torchvision
import torchvision
调用成功。
调用torchvision下面的模型,并查看这些模型
from torchvision import modelsdir(models)
可以看到模型的类别很多。
我们直接调用其下训练好的resnet网络
resnet = models.resnet101(pretrained=True)resnet
将其设置为评估模式
resnet.eval()
2. 识别图片
随便从网上找一张猫的图片,并读取
from PIL import Image
img = Image.open("cat.jpg")img
对图片进行处理
import torch
from torchvision import transforms
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)])
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
通过resnet对图片进行识别
output = resnet(batch_t)
output
其中权值最高的类型就是resnet对图片的识别。那么,每一类代表什么呢?我们需要找到分类表。resnet的分类表可以从PyTorch-Hub项目中找到,所以先克隆该项目
!git clone https://gitee.com/mirrors/PyTorch-Hub.git
然后从分类表中找出标签列表
with open('PyTorch-Hub/imagenet_classes.txt') as f:
labels = [line.strip() for line in f.readlines()]
找到输出值中最大的那个类型
_, index = torch.max(output, 1)
_, index
计算该类型的占比
percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100
percentage
查看结果
labels[index[0]], percentage[index[0]].item()
可以看到识别成功,该图片有39.95%的可能是一只虎斑猫。
那么,我们可能还想知道剩下的60%的可能性是什么。对结果中的各类型进行排序,并输出前五个结果
_, indices = torch.sort(output, descending=True)
[(labels[idx], percentage[idx].item()) for idx in indices[0][:5]]
可以看到排在第二的是埃及猫,前两者相加达到75%左右。
- 点赞
- 收藏
- 关注作者
评论(0)