PyTorch 和 Albumentations 实现图像分类(猫狗大战)
目录
使用Albumentations定义训练和验证数据集的转换函数
摘要
本示例说明如何使用 Albumentations 进行数据增强，并借助 PyTorch 完成图像分类。我们将使用“猫与狗”（Cats vs. Dogs）数据集，任务是判断图像中是猫还是狗。
导入所需的库
from collections import defaultdict
import copy
import random
import os
import shutil
from urllib.request import urlretrieve
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
# Let cuDNN benchmark the available convolution algorithms and cache the
# fastest one; this pays off when input shapes stay the same between iterations.
cudnn.benchmark = True
下载数据集并解压缩
class TqdmUpTo(tqdm):
    """tqdm progress bar whose ``update_to`` can serve as a ``urlretrieve`` reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position ``b * bsize`` bytes.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block in bytes.
            tsize: Total size in bytes. ``urlretrieve`` passes -1 when the
                server does not report a size; non-positive values are ignored
                so the bar is never given a negative total.
        """
        if tsize is not None and tsize > 0:
            self.total = tsize
        # The hook reports cumulative progress, while tqdm.update expects a delta.
        self.update(b * bsize - self.n)
def download_url(url, filepath):
    """Download ``url`` to ``filepath`` with a progress bar, unless the file exists.

    The data is written to ``filepath + ".part"`` first and renamed to
    ``filepath`` only after the transfer completes. An interrupted download
    therefore never leaves a truncated file at ``filepath`` that later runs
    would mistake for a finished one and skip.
    """
    directory = os.path.dirname(os.path.abspath(filepath))
    os.makedirs(directory, exist_ok=True)
    if os.path.exists(filepath):
        print("Filepath already exists. Skipping download.")
        return
    partial_filepath = filepath + ".part"
    try:
        with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:
            urlretrieve(url, filename=partial_filepath, reporthook=t.update_to, data=None)
            t.total = t.n
    except BaseException:
        # Also reached on Ctrl+C: drop the incomplete file, then propagate.
        if os.path.exists(partial_filepath):
            os.remove(partial_filepath)
        raise
    os.replace(partial_filepath, filepath)
def extract_archive(filepath):
    """Unpack the archive at ``filepath`` into the directory that contains it."""
    archive_path = os.path.abspath(filepath)
    shutil.unpack_archive(filepath, os.path.dirname(archive_path))
设置下载数据集的目录
# Directory into which the archive is downloaded and unpacked.
dataset_directory = "datasets/cats-vs-dogs"
下载数据集并解压
# Fetch the Microsoft-hosted Kaggle "Cats vs Dogs" zip (download_url skips
# the download when the file already exists) and unpack it next to the zip.
filepath = os.path.join(dataset_directory, "kagglecatsanddogs_3367a.zip")
download_url(
    url="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",
    filepath=filepath,
)
extract_archive(filepath)
切分训练集、验证集和测试集
数据集中的某些文件已损坏,因此我们将仅使用OpenCV可以正确加载的那些图像文件。 我们将使用20000张图像进行训练,使用4936张图像进行验证,并使用10张图像进行测试。
# Gather every file under PetImages/Cat and PetImages/Dog. Sorting fixes the
# order, so the seeded shuffle below does not depend on os.listdir order.
root_directory = os.path.join(dataset_directory, "PetImages")
cat_directory = os.path.join(root_directory, "Cat")
dog_directory = os.path.join(root_directory, "Dog")
cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])
dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])
images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]
# Keep only files OpenCV can decode (cv2.imread returns None otherwise).
# Note that this decodes every image once.
correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]
# Seeded shuffle -> reproducible split: the first 20000 paths for training,
# the last 10 for testing, everything in between for validation.
random.seed(42)
random.shuffle(correct_images_filepaths)
train_images_filepaths = correct_images_filepaths[:20000]
val_images_filepaths = correct_images_filepaths[20000:-10]
test_images_filepaths = correct_images_filepaths[-10:]
print(len(train_images_filepaths), len(val_images_filepaths), len(test_images_filepaths))
20000 4936 10
定义一个可视化图像及其标签的函数
让我们定义一个函数,该函数将获取图像文件路径及其标签的列表,并在网格中将其可视化。 正确的标签为绿色,错误预测的标签为红色。
def display_image_grid(images_filepaths, predicted_labels=(), cols=5):
    """Show images in a grid, each titled with its (predicted) label.

    The true label of an image is the name of its parent directory. A title
    is green when the predicted label equals the true label and red
    otherwise. Without ``predicted_labels``, the true labels are shown.

    Args:
        images_filepaths: Paths of images readable by OpenCV.
        predicted_labels: Optional labels aligned with ``images_filepaths``.
        cols: Number of grid columns.
    """
    # Ceiling division keeps a partially filled last row; floor division
    # dropped it (IndexError) or asked for 0 rows with fewer than ``cols`` images.
    rows = max(1, -(-len(images_filepaths) // cols))
    # squeeze=False keeps ``ax`` a 2-D array even for a 1x1 grid.
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6), squeeze=False)
    axes = ax.ravel()
    for i, image_filepath in enumerate(images_filepaths):
        image = cv2.imread(image_filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
        predicted_label = predicted_labels[i] if predicted_labels else true_label
        color = "green" if true_label == predicted_label else "red"
        axes[i].imshow(image)
        axes[i].set_title(predicted_label, color=color)
        axes[i].set_axis_off()
    # Blank out the unused cells of a partial last row.
    for unused_axis in axes[len(images_filepaths):]:
        unused_axis.set_axis_off()
    plt.tight_layout()
    plt.show()
# Preview the held-out test images with their ground-truth labels.
display_image_grid(test_images_filepaths)
定义一个PyTorch数据集类
接下来，我们定义一个 PyTorch 数据集。如果您不熟悉 PyTorch 数据集，请参阅本教程：https://pytorch.org/tutorials/beginner/data_loading_tutorial.html 。这是一个二分类任务——模型需要预测图像中是猫还是狗。标签表示图像中是猫的概率，因此猫图像的正确标签为 1.0，狗图像的正确标签为 0.0。__init__ 接收一个可选的 transform 参数，它是一个 Albumentations 增强管道的变换函数。在 __getitem__ 中，Dataset 类使用该函数对图像进行增强，并返回图像及其正确标签。
class CatsVsDogsDataset(Dataset):
    """Labelled Cats-vs-Dogs dataset backed by a list of image paths.

    The label comes from the image's parent directory name: 1.0 for ``Cat``
    and 0.0 for anything else.
    """

    def __init__(self, images_filepaths, transform=None):
        """
        Args:
            images_filepaths: Paths of the images in this split.
            transform: Optional Albumentations pipeline, called as
                ``transform(image=...)["image"]``.
        """
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th path.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        if image is None:
            # cv2.imread signals unreadable files by returning None; report the
            # path instead of failing opaquely inside cvtColor.
            raise ValueError("Could not read image: {}".format(image_filepath))
        # OpenCV decodes to BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = 1.0 if os.path.normpath(image_filepath).split(os.sep)[-2] == "Cat" else 0.0
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
使用Albumentations定义训练和验证数据集的转换函数
我们使用 Albumentations 定义训练集和验证集的数据增强管道。在这两个管道中，我们首先调整输入图像的大小，使其最短边为 160px，然后裁剪出 128px × 128px 的区域。对于训练集，我们还会对裁剪后的图像施加更多的增强。接下来对图像进行归一化：先将所有像素值除以 255，使每个像素值位于 [0.0, 1.0] 范围内；然后减去平均像素值，再除以标准差。增强管道中的均值和标准差取自 ImageNet 数据集，但它们同样能很好地迁移到“猫与狗”数据集。之后，我们使用 ToTensorV2 将 NumPy 数组转换为 PyTorch 张量，作为神经网络的输入。请注意，在验证管道中我们使用 A.CenterCrop 而不是 A.RandomCrop，因为我们希望验证结果是确定性的（不依赖于裁剪的随机位置）。
# Training pipeline: resize so the shortest side is 160 px, apply random
# shift/scale/rotate, take a random 128x128 crop, jitter colors, normalize with
# ImageNet mean/std, and convert the array to a torch tensor.
train_transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=160),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.RandomCrop(height=128, width=128),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)
# Validation pipeline: same resize and normalization, but a deterministic
# center crop and no random augmentation, so results are reproducible.
val_transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=160),
        A.CenterCrop(height=128, width=128),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)
还让我们定义一个函数,该函数采用数据集并可视化应用于同一图像的不同增强。
def visualize_augmentations(dataset, idx=0, samples=10, cols=5):
    """Plot ``samples`` independently augmented versions of one dataset image.

    Works on a deep copy of ``dataset`` whose pipeline has ``A.Normalize`` and
    ``ToTensorV2`` removed, so the images are shown in their original pixel
    range and the caller's dataset is not modified.
    """
    dataset = copy.deepcopy(dataset)
    dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])
    # Ceiling division keeps a partially filled last row; squeeze=False keeps
    # ``ax`` a 2-D array even for a 1x1 grid.
    rows = max(1, -(-samples // cols))
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6), squeeze=False)
    axes = ax.ravel()
    for i in range(samples):
        image, _ = dataset[idx]
        axes[i].imshow(image)
        axes[i].set_axis_off()
    for unused_axis in axes[samples:]:
        unused_axis.set_axis_off()
    plt.tight_layout()
    plt.show()
# Seed Python's `random` so the augmentation preview is repeatable.
# NOTE(review): only effective if the installed Albumentations draws from
# Python's `random` module — confirm for your version.
random.seed(42)
visualize_augmentations(train_dataset)
定义训练辅助方法
我们定义训练所需的辅助方法。calculate_accuracy 接受模型预测和真实标签，返回这些预测的准确率。MetricMonitor 用于在训练和验证过程中跟踪准确率、损失等指标。
def calculate_accuracy(output, target):
    """Return the fraction of samples whose thresholded prediction matches the label.

    ``output`` holds raw logits; a sigmoid probability >= 0.5 counts as the
    positive class. ``target`` holds 1.0/0.0 labels (the callers reshape it to
    ``(N, 1)``). Matches are summed along dim 0 and divided by the batch size,
    so ``.item()`` requires a single output column.
    """
    predicted_positive = torch.sigmoid(output) >= 0.5
    actual_positive = target == 1.0
    correct_per_column = (predicted_positive == actual_positive).sum(dim=0)
    batch_size = output.size(0)
    return torch.true_divide(correct_per_column, batch_size).item()
class MetricMonitor:
    """Track running averages of named metrics such as loss and accuracy.

    Each metric is a dict with ``"val"`` (weighted sum of all updates),
    ``"count"`` (sum of the weights) and ``"avg"`` (``val / count``).
    """

    def __init__(self, float_precision=3):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Forget all accumulated metrics."""
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val, n=1):
        """Record one observation of ``metric_name``.

        Args:
            metric_name: Name shown in ``str(self)``.
            val: Observed value, e.g. a per-batch mean.
            n: Positive weight of the observation. Passing the batch size
                makes ``avg`` a true per-sample mean even when the last batch
                is smaller; the default 1 weights all observations equally.

        Raises:
            ValueError: If ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError("n must be positive, got {}".format(n))
        metric = self.metrics[metric_name]
        metric["val"] += val * n
        metric["count"] += n
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        return " | ".join(
            "{metric_name}: {avg:.{float_precision}f}".format(
                metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision
            )
            for metric_name, metric in self.metrics.items()
        )
定义训练参数
在这里,我们定义了一些训练参数,例如模型架构,学习率,batch_size,epochs等
# Training hyper-parameters. The device falls back to the CPU when CUDA is
# unavailable instead of failing on the first ``.to("cuda")`` call.
params = {
    "model": "resnet50",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "lr": 0.001,
    "batch_size": 64,
    "num_workers": 4,
    "epochs": 10,
}
训练和验证
# Build the torchvision architecture named in params["model"] from scratch
# (no pretrained weights) with a single output logit per image.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=None`; confirm against the installed version.
model = getattr(models, params["model"])(pretrained=False, num_classes=1,)
model = model.to(params["device"])
# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
criterion = nn.BCEWithLogitsLoss().to(params["device"])
optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(
    train_dataset, batch_size=params["batch_size"], shuffle=True, num_workers=params["num_workers"], pin_memory=True,
)
val_loader = DataLoader(
    val_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"], pin_memory=True,
)
def train(train_loader, model, criterion, optimizer, epoch, params):
    """Run one training epoch, showing running loss/accuracy in a tqdm bar.

    Args:
        train_loader: DataLoader yielding ``(images, target)`` batches.
        model: Network producing one logit per image.
        criterion: Loss applied to ``(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number; only used in the progress-bar text.
        params: Settings dict; only ``params["device"]`` is read.
    """
    device = params["device"]
    monitor = MetricMonitor()
    model.train()
    progress = tqdm(train_loader)
    for images, target in progress:
        images = images.to(device, non_blocking=True)
        # Reshape labels to (N, 1) float32 so they line up with the logits.
        target = target.to(device, non_blocking=True).float().view(-1, 1)
        output = model(images)
        loss = criterion(output, target)
        monitor.update("Loss", loss.item())
        monitor.update("Accuracy", calculate_accuracy(output, target))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progress.set_description(
            "Epoch: {epoch}. Train. {metric_monitor}".format(epoch=epoch, metric_monitor=monitor)
        )
def validate(val_loader, model, criterion, epoch, params):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Running loss/accuracy is shown in a tqdm progress bar; nothing is
    returned. ``params["device"]`` selects where batches are moved, and
    ``epoch`` is only used in the progress-bar text.
    """
    device = params["device"]
    monitor = MetricMonitor()
    model.eval()
    progress = tqdm(val_loader)
    with torch.no_grad():
        for images, target in progress:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True).float().view(-1, 1)
            output = model(images)
            loss = criterion(output, target)
            monitor.update("Loss", loss.item())
            monitor.update("Accuracy", calculate_accuracy(output, target))
            progress.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=monitor)
            )
训练模型
# One training pass followed by one validation pass per epoch (epochs are 1-based).
for epoch in range(1, params["epochs"] + 1):
    train(train_loader, model, criterion, optimizer, epoch, params)
    validate(val_loader, model, criterion, epoch, params)
Epoch: 1. Train. Loss: 0.700 | Accuracy: 0.598: 100%|██████████| 313/313 [00:38<00:00, 8.04it/s]
Epoch: 1. Validation. Loss: 0.684 | Accuracy: 0.663: 100%|██████████| 78/78 [00:03<00:00, 23.46it/s]
Epoch: 2. Train. Loss: 0.611 | Accuracy: 0.675: 100%|██████████| 313/313 [00:37<00:00, 8.24it/s]
Epoch: 2. Validation. Loss: 0.581 | Accuracy: 0.689: 100%|██████████| 78/78 [00:03<00:00, 23.25it/s]
Epoch: 3. Train. Loss: 0.513 | Accuracy: 0.752: 100%|██████████| 313/313 [00:38<00:00, 8.22it/s]
Epoch: 3. Validation. Loss: 0.408 | Accuracy: 0.818: 100%|██████████| 78/78 [00:03<00:00, 23.61it/s]
Epoch: 4. Train. Loss: 0.440 | Accuracy: 0.796: 100%|██████████| 313/313 [00:37<00:00, 8.24it/s]
Epoch: 4. Validation. Loss: 0.374 | Accuracy: 0.829: 100%|██████████| 78/78 [00:03<00:00, 22.89it/s]
Epoch: 5. Train. Loss: 0.391 | Accuracy: 0.821: 100%|██████████| 313/313 [00:37<00:00, 8.25it/s]
Epoch: 5. Validation. Loss: 0.345 | Accuracy: 0.853: 100%|██████████| 78/78 [00:03<00:00, 23.03it/s]
Epoch: 6. Train. Loss: 0.343 | Accuracy: 0.845: 100%|██████████| 313/313 [00:38<00:00, 8.22it/s]
Epoch: 6. Validation. Loss: 0.304 | Accuracy: 0.861: 100%|██████████| 78/78 [00:03<00:00, 23.88it/s]
Epoch: 7. Train. Loss: 0.312 | Accuracy: 0.858: 100%|██████████| 313/313 [00:38<00:00, 8.23it/s]
Epoch: 7. Validation. Loss: 0.259 | Accuracy: 0.886: 100%|██████████| 78/78 [00:03<00:00, 23.29it/s]
Epoch: 8. Train. Loss: 0.284 | Accuracy: 0.875: 100%|██████████| 313/313 [00:38<00:00, 8.21it/s]
Epoch: 8. Validation. Loss: 0.304 | Accuracy: 0.882: 100%|██████████| 78/78 [00:03<00:00, 23.81it/s]
Epoch: 9. Train. Loss: 0.265 | Accuracy: 0.884: 100%|██████████| 313/313 [00:38<00:00, 8.18it/s]
Epoch: 9. Validation. Loss: 0.255 | Accuracy: 0.888: 100%|██████████| 78/78 [00:03<00:00, 23.78it/s]
Epoch: 10. Train. Loss: 0.248 | Accuracy: 0.890: 100%|██████████| 313/313 [00:38<00:00, 8.21it/s]
Epoch: 10. Validation. Loss: 0.222 | Accuracy: 0.909: 100%|██████████| 78/78 [00:03<00:00, 23.90it/s]
预测图像标签并可视化这些预测
现在我们已经有了训练好的模型，接下来尝试预测一些图像的标签，看看这些预测是否正确。首先我们创建 CatsVsDogsInferenceDataset 这个 PyTorch 数据集。它的代码与训练和验证数据集类似，但推理数据集只返回图像，不返回对应的标签（因为在实际应用中，我们通常无法获得真实标签，而是希望用训练好的模型来推断它们）。
class CatsVsDogsInferenceDataset(Dataset):
    """Unlabelled image dataset for prediction; ``__getitem__`` returns only the image."""

    def __init__(self, images_filepaths, transform=None):
        """
        Args:
            images_filepaths: Paths of the images to predict on.
            transform: Optional Albumentations pipeline, called as
                ``transform(image=...)["image"]``.
        """
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at position ``idx``.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        if image is None:
            # cv2.imread returns None for unreadable files; report the path.
            raise ValueError("Could not read image: {}".format(image_filepath))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image
# Test-time preprocessing mirrors the deterministic validation pipeline.
test_transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=160),
        A.CenterCrop(height=128, width=128),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
test_dataset = CatsVsDogsInferenceDataset(images_filepaths=test_images_filepaths, transform=test_transform)
test_loader = DataLoader(
    test_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"], pin_memory=True,
)
# eval() switches layers such as BatchNorm to inference behaviour.
model = model.eval()
predicted_labels = []
with torch.no_grad():
    for images in test_loader:
        images = images.to(params["device"], non_blocking=True)
        output = model(images)
        # Probability >= 0.5 means "Cat" (label 1.0 in training); column 0 of
        # the single-logit output gives one boolean per image.
        predictions = (torch.sigmoid(output) >= 0.5)[:, 0].cpu().numpy()
        predicted_labels += ["Cat" if is_cat else "Dog" for is_cat in predictions]
display_image_grid(test_images_filepaths, predicted_labels)
完整代码
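上面的代码逻辑本身没有问题，但分散在各小节中、顺序有点乱，按顺序直接运行训练会报错（DataLoader 设置 num_workers>0 时会用子进程加载数据，在 Windows 等以 spawn 方式启动子进程的平台上，需要把主流程放在 if __name__ == '__main__': 之下）。因此我把代码重新整理，并做了适当的修改。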
from collections import defaultdict
import copy
import random
import os
import shutil
from urllib.request import urlretrieve
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
# Let cuDNN benchmark the available convolution algorithms and cache the
# fastest one; this pays off when input shapes stay the same between iterations.
cudnn.benchmark = True
class TqdmUpTo(tqdm):
    """tqdm progress bar whose ``update_to`` can serve as a ``urlretrieve`` reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position ``b * bsize`` bytes.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block in bytes.
            tsize: Total size in bytes. ``urlretrieve`` passes -1 when the
                server does not report a size; non-positive values are ignored
                so the bar is never given a negative total.
        """
        if tsize is not None and tsize > 0:
            self.total = tsize
        # The hook reports cumulative progress, while tqdm.update expects a delta.
        self.update(b * bsize - self.n)
def download_url(url, filepath):
    """Download ``url`` to ``filepath`` with a progress bar, unless the file exists.

    The data is written to ``filepath + ".part"`` first and renamed to
    ``filepath`` only after the transfer completes. An interrupted download
    therefore never leaves a truncated file at ``filepath`` that later runs
    would mistake for a finished one and skip.
    """
    directory = os.path.dirname(os.path.abspath(filepath))
    os.makedirs(directory, exist_ok=True)
    if os.path.exists(filepath):
        print("Filepath already exists. Skipping download.")
        return
    partial_filepath = filepath + ".part"
    try:
        with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:
            urlretrieve(url, filename=partial_filepath, reporthook=t.update_to, data=None)
            t.total = t.n
    except BaseException:
        # Also reached on Ctrl+C: drop the incomplete file, then propagate.
        if os.path.exists(partial_filepath):
            os.remove(partial_filepath)
        raise
    os.replace(partial_filepath, filepath)
def extract_archive(filepath):
    """Unpack the archive at ``filepath`` into the directory that contains it."""
    archive_path = os.path.abspath(filepath)
    shutil.unpack_archive(filepath, os.path.dirname(archive_path))
def display_image_grid(images_filepaths, predicted_labels=(), cols=5):
    """Show images in a grid, each titled with its (predicted) label.

    The true label of an image is the name of its parent directory. A title
    is green when the predicted label equals the true label and red
    otherwise. Without ``predicted_labels``, the true labels are shown.

    Args:
        images_filepaths: Paths of images readable by OpenCV.
        predicted_labels: Optional labels aligned with ``images_filepaths``.
        cols: Number of grid columns.
    """
    # Ceiling division keeps a partially filled last row; floor division
    # dropped it (IndexError) or asked for 0 rows with fewer than ``cols`` images.
    rows = max(1, -(-len(images_filepaths) // cols))
    # squeeze=False keeps ``ax`` a 2-D array even for a 1x1 grid.
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6), squeeze=False)
    axes = ax.ravel()
    for i, image_filepath in enumerate(images_filepaths):
        image = cv2.imread(image_filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
        predicted_label = predicted_labels[i] if predicted_labels else true_label
        color = "green" if true_label == predicted_label else "red"
        axes[i].imshow(image)
        axes[i].set_title(predicted_label, color=color)
        axes[i].set_axis_off()
    # Blank out the unused cells of a partial last row.
    for unused_axis in axes[len(images_filepaths):]:
        unused_axis.set_axis_off()
    plt.tight_layout()
    plt.show()
class CatsVsDogsDataset(Dataset):
    """Labelled Cats-vs-Dogs dataset backed by a list of image paths.

    The label comes from the image's parent directory name: 1.0 for ``Cat``
    and 0.0 for anything else.
    """

    def __init__(self, images_filepaths, transform=None):
        """
        Args:
            images_filepaths: Paths of the images in this split.
            transform: Optional Albumentations pipeline, called as
                ``transform(image=...)["image"]``.
        """
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th path.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        if image is None:
            # cv2.imread signals unreadable files by returning None; report the
            # path instead of failing opaquely inside cvtColor.
            raise ValueError("Could not read image: {}".format(image_filepath))
        # OpenCV decodes to BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = 1.0 if os.path.normpath(image_filepath).split(os.sep)[-2] == "Cat" else 0.0
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
def visualize_augmentations(dataset, idx=0, samples=10, cols=5):
    """Plot ``samples`` independently augmented versions of one dataset image.

    Works on a deep copy of ``dataset`` whose pipeline has ``A.Normalize`` and
    ``ToTensorV2`` removed, so the images are shown in their original pixel
    range and the caller's dataset is not modified.
    """
    dataset = copy.deepcopy(dataset)
    dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])
    # Ceiling division keeps a partially filled last row; squeeze=False keeps
    # ``ax`` a 2-D array even for a 1x1 grid.
    rows = max(1, -(-samples // cols))
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6), squeeze=False)
    axes = ax.ravel()
    for i in range(samples):
        image, _ = dataset[idx]
        axes[i].imshow(image)
        axes[i].set_axis_off()
    for unused_axis in axes[samples:]:
        unused_axis.set_axis_off()
    plt.tight_layout()
    plt.show()
def calculate_accuracy(output, target):
    """Return the fraction of samples whose thresholded prediction matches the label.

    ``output`` holds raw logits; a sigmoid probability >= 0.5 counts as the
    positive class. ``target`` holds 1.0/0.0 labels (the callers reshape it to
    ``(N, 1)``). Matches are summed along dim 0 and divided by the batch size,
    so ``.item()`` requires a single output column.
    """
    predicted_positive = torch.sigmoid(output) >= 0.5
    actual_positive = target == 1.0
    correct_per_column = (predicted_positive == actual_positive).sum(dim=0)
    batch_size = output.size(0)
    return torch.true_divide(correct_per_column, batch_size).item()
class MetricMonitor:
    """Track running averages of named metrics such as loss and accuracy.

    Each metric is a dict with ``"val"`` (weighted sum of all updates),
    ``"count"`` (sum of the weights) and ``"avg"`` (``val / count``).
    """

    def __init__(self, float_precision=3):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Forget all accumulated metrics."""
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val, n=1):
        """Record one observation of ``metric_name``.

        Args:
            metric_name: Name shown in ``str(self)``.
            val: Observed value, e.g. a per-batch mean.
            n: Positive weight of the observation. Passing the batch size
                makes ``avg`` a true per-sample mean even when the last batch
                is smaller; the default 1 weights all observations equally.

        Raises:
            ValueError: If ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError("n must be positive, got {}".format(n))
        metric = self.metrics[metric_name]
        metric["val"] += val * n
        metric["count"] += n
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        return " | ".join(
            "{metric_name}: {avg:.{float_precision}f}".format(
                metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision
            )
            for metric_name, metric in self.metrics.items()
        )
def train(train_loader, model, criterion, optimizer, epoch, params):
    """Run one training epoch, showing running loss/accuracy in a tqdm bar.

    Args:
        train_loader: DataLoader yielding ``(images, target)`` batches.
        model: Network producing one logit per image.
        criterion: Loss applied to ``(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number; only used in the progress-bar text.
        params: Settings dict; only ``params["device"]`` is read.
    """
    device = params["device"]
    monitor = MetricMonitor()
    model.train()
    progress = tqdm(train_loader)
    for images, target in progress:
        images = images.to(device, non_blocking=True)
        # Reshape labels to (N, 1) float32 so they line up with the logits.
        target = target.to(device, non_blocking=True).float().view(-1, 1)
        output = model(images)
        loss = criterion(output, target)
        monitor.update("Loss", loss.item())
        monitor.update("Accuracy", calculate_accuracy(output, target))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progress.set_description(
            "Epoch: {epoch}. Train. {metric_monitor}".format(epoch=epoch, metric_monitor=monitor)
        )
def validate(val_loader, model, criterion, epoch, params):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Running loss/accuracy is shown in a tqdm progress bar; nothing is
    returned. ``params["device"]`` selects where batches are moved, and
    ``epoch`` is only used in the progress-bar text.
    """
    device = params["device"]
    monitor = MetricMonitor()
    model.eval()
    progress = tqdm(val_loader)
    with torch.no_grad():
        for images, target in progress:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True).float().view(-1, 1)
            output = model(images)
            loss = criterion(output, target)
            monitor.update("Loss", loss.item())
            monitor.update("Accuracy", calculate_accuracy(output, target))
            progress.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=monitor)
            )
class CatsVsDogsInferenceDataset(Dataset):
    """Unlabelled image dataset for prediction; ``__getitem__`` returns only the image."""

    def __init__(self, images_filepaths, transform=None):
        """
        Args:
            images_filepaths: Paths of the images to predict on.
            transform: Optional Albumentations pipeline, called as
                ``transform(image=...)["image"]``.
        """
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at position ``idx``.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        if image is None:
            # cv2.imread returns None for unreadable files; report the path.
            raise ValueError("Could not read image: {}".format(image_filepath))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image
if __name__ == '__main__':
    # The guard is required because DataLoader workers (num_workers > 0) may
    # re-import this module in child processes.

    # --- Data: download, unpack and split ---------------------------------
    dataset_directory = "datasets/cats-vs-dogs"
    filepath = os.path.join(dataset_directory, "kagglecatsanddogs_3367a.zip")
    download_url(
        url="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",
        filepath=filepath,
    )
    extract_archive(filepath)
    root_directory = os.path.join(dataset_directory, "PetImages")
    cat_directory = os.path.join(root_directory, "Cat")
    dog_directory = os.path.join(root_directory, "Dog")
    cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])
    dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])
    images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]
    # Keep only files OpenCV can decode (cv2.imread returns None otherwise).
    correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]
    # Seeded shuffle -> reproducible split: 20000 train, last 10 test, rest validation.
    random.seed(42)
    random.shuffle(correct_images_filepaths)
    train_images_filepaths = correct_images_filepaths[:20000]
    val_images_filepaths = correct_images_filepaths[20000:-10]
    test_images_filepaths = correct_images_filepaths[-10:]
    print(len(train_images_filepaths), len(val_images_filepaths), len(test_images_filepaths))
    display_image_grid(test_images_filepaths)

    # --- Augmentation pipelines and datasets -------------------------------
    # Training: resize, random geometric/color augmentation, random 128x128 crop.
    train_transform = A.Compose(
        [
            A.SmallestMaxSize(max_size=160),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.RandomCrop(height=128, width=128),
            A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)
    random.seed(42)
    visualize_augmentations(train_dataset)
    # Validation: deterministic center crop so scores are reproducible.
    val_transform = A.Compose(
        [
            A.SmallestMaxSize(max_size=160),
            A.CenterCrop(height=128, width=128),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)

    # --- Model and training -----------------------------------------------
    params = {
        "model": "resnet50",
        # Fall back to the CPU instead of crashing when CUDA is unavailable.
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "lr": 0.001,
        "batch_size": 64,
        "num_workers": 4,
        "epochs": 10,
    }
    # Trained from scratch with one output logit per image.
    # NOTE(review): newer torchvision deprecates `pretrained=` in favour of `weights=None`.
    model = getattr(models, params["model"])(pretrained=False, num_classes=1, )
    model = model.to(params["device"])
    # BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
    criterion = nn.BCEWithLogitsLoss().to(params["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    train_loader = DataLoader(
        train_dataset, batch_size=params["batch_size"], shuffle=True, num_workers=params["num_workers"],
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"], pin_memory=True,
    )
    for epoch in range(1, params["epochs"] + 1):
        train(train_loader, model, criterion, optimizer, epoch, params)
        validate(val_loader, model, criterion, epoch, params)

    # --- Inference on the held-out test images ----------------------------
    test_transform = A.Compose(
        [
            A.SmallestMaxSize(max_size=160),
            A.CenterCrop(height=128, width=128),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )
    test_dataset = CatsVsDogsInferenceDataset(images_filepaths=test_images_filepaths, transform=test_transform)
    test_loader = DataLoader(
        test_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"],
        pin_memory=True,
    )
    model = model.eval()
    predicted_labels = []
    with torch.no_grad():
        for images in test_loader:
            images = images.to(params["device"], non_blocking=True)
            output = model(images)
            # Probability >= 0.5 means "Cat" (label 1.0 during training).
            predictions = (torch.sigmoid(output) >= 0.5)[:, 0].cpu().numpy()
            predicted_labels += ["Cat" if is_cat else "Dog" for is_cat in predictions]
    display_image_grid(test_images_filepaths, predicted_labels)
华为开发者空间发布
让每位开发者拥有一台云主机
- 点赞
- 收藏
- 关注作者
评论(0)