机器学习之卷积神经网络使用cifar10数据集和alexnet网络模型训练分类模型,安装labelimg,以及报错ERROR
【摘要】 使用cifar10数据集和alexnet网络模型训练分类模型 下载cifar10数据集 代码:import torchvisionimport torchtransform = torchvision.transforms.Compose( [torchvision.transforms.ToTensor(), torchvision.transforms.Resize(22...
使用cifar10数据集和alexnet网络模型训练分类模型
下载cifar10数据集
代码:
import torchvision
import torch
transform = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),
torchvision.transforms.Resize(224)]
)
train_set = torchvision.datasets.CIFAR10(root='./',download=False,train=True,transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./',download=False,train=False,transform=transform)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=8,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set,batch_size=8,shuffle=True)
class Alexnet(torch.nn.Module): #1080 2080
def __init__(self,num_classes=10):
super(Alexnet,self).__init__()
net = torchvision.models.alexnet(pretrained=False) #迁移学习
net.classifier = torch.nn.Sequential()
self.features = net
self.classifier = torch.nn.Sequential(
torch.nn.Dropout(0.3),
torch.nn.Linear(256 * 6 * 6, 4096),
torch.nn.ReLU(inplace=True),
torch.nn.Dropout(0.3),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(4096, num_classes),
)
def forward(self,x):
x = self.features(x)
x = x.view(x.size(0),-1)
x = self.classifier(x)
return x
device = torch.device('cpu')
net = Alexnet().to(device)
loss_func = torch.nn.CrossEntropyLoss().to(device)
optim = torch.optim.Adam(net.parameters(),lr=0.001)
net.train()
for epoch in range(10):
for step,(x,y) in enumerate(train_loader): # 28*28*1 32*32*3
x,y = x.to(device),y.to(device)
output = net(x)
loss = loss_func(output,y)
optim.zero_grad()
loss.backward()
optim.step()
print("epoch:",epoch,'loss:',loss)
安装labelimg,以及报错
目标检测标注工具:labelimg
安装 pip install labelimg
使用 labelimg
报错
ERROR: spyder 4.1.4 requires pyqtwebengine<5.13; python_version >= “3”, which is not installed. ERROR: spyder 4.1.4 has requirement pyqt5<5.13; python_version >= “3”, but you’ll have pyqt5 5.15.6 which is incompatible
版本不匹配问题
打开Anaconda Prompt
使用命令安装Spyder
pip install spyder==4.1.4
或者
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple/ spyder==4.1.4
使用 labelimg
在安装环境下找到labelimg.exe复制到桌面
打开
打开一张图片
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)