记一次图像分类排错

举报
Jack20 发表于 2025/12/24 20:21:17 2025/12/24
【摘要】 当CV模型-图像分类2.0进行模型训练时,模型训练报错[rank6]: UnboundLocalError: local variable 'avg' referenced before assignment时一、先搞懂报错核心原因这么说吧,UnboundLocalError的本质不是 CV 模型本身的问题,而是代码里的 “逻辑覆盖不全”:你代码中后续要使用avg变量(比如打印平均损失、返回...

当CV模型-图像分类2.0进行模型训练时,模型训练报错[rank6]: UnboundLocalError: local variable 'avg' referenced before assignment时

一、先搞懂报错核心原因

这么说吧,UnboundLocalError的本质不是 CV 模型本身的问题,而是代码里的 “逻辑覆盖不全”:
你代码中后续要使用avg变量(比如打印平均损失、返回平均准确率),但avg只在有数据的分支里赋值(比如训练循环执行时);当 rank6 节点的训练批次为空(比如分布式数据分片时,该 rank 分到的样本数为 0),赋值avg的逻辑没触发,后续引用avg就会直接报错。
在图像分类 2.0 模型训练中,这个问题最常出现在两个场景:
  1. 训练循环无数据:rank6 的 dataloader 为空(总样本数不能被 rank 数整除,最后一个 rank 没样本),计算平均损失 / 准确率的循环没执行,avg未赋值;
  2. 条件分支覆盖不全avg仅在if分支(比如 “有预测结果”)里赋值,else分支没赋值,却在后续直接引用。

二、分步解决:先定位,再修复

步骤 1:定位报错的具体代码行

先看报错堆栈(关键!),比如报错日志会显示:
 
[rank6]: UnboundLocalError: local variable 'avg' referenced before assignment
File "train.py", line 289, in train_one_epoch
logger.info(f"Rank {rank} epoch {epoch} avg acc: {avg_acc}")
 
这就明确了:train_one_epoch函数的 289 行引用了avg_acc,但这个变量可能没赋值。

步骤 2:分析图像分类训练中最典型的错误逻辑

咱们不妨看看图像分类训练里,计算平均指标的常见错误代码(这也是你大概率踩坑的地方):
 
def train_one_epoch(model, dataloader, criterion, optimizer, rank):
losses = []
correct = 0
total = 0
# 训练循环:若dataloader为空,这个循环完全不执行
for images, labels in dataloader:
# 前向传播(图像分类核心逻辑)
outputs = model(images)
loss = criterion(outputs, labels)
losses.append(loss.item())

# 计算准确率
_, preds = torch.max(outputs, 1)
correct += (preds == labels).sum().item()
total += labels.size(0)

# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 问题1:若循环没执行,sum(losses)/len(losses)不会执行,avg_loss未赋值
avg_loss = sum(losses) / len(losses)
# 问题2:若total=0,correct/total会除以0,且avg_acc未赋值
avg_acc = correct / total

# 引用未赋值的avg_loss/avg_acc,rank6无数据时直接报错
print(f"Rank {rank} avg loss: {avg_loss}, avg acc: {avg_acc}")
return avg_loss, avg_acc
 

步骤 3:修复代码(核心!确保 avg 在所有场景下有赋值)

核心修复思路:初始化 avg 为默认值 + 加条件判断避免空数据 / 除以 0,修复后的代码如下(可直接替换):
 
def train_one_epoch(model, dataloader, criterion, optimizer, rank):
losses = []
correct = 0
total = 0
# 修复1:初始化avg变量,避免未赋值(关键!)
avg_loss = 0.0
avg_acc = 0.0

for images, labels in dataloader:
outputs = model(images)
loss = criterion(outputs, labels)
losses.append(loss.item())

_, preds = torch.max(outputs, 1)
correct += (preds == labels).sum().item()
total += labels.size(0)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# 修复2:加条件判断,仅当有数据时计算平均值
if len(losses) > 0:
avg_loss = sum(losses) / len(losses)
else:
# 空数据时记录日志,方便排查
print(f"[Rank {rank}] Warning: 无训练数据,avg_loss设为0.0")

# 修复3:避免除以0,同时处理空数据
if total > 0:
avg_acc = correct / total
else:
print(f"[Rank {rank}] Warning: 无有效样本,avg_acc设为0.0")

# 此时avg_loss/avg_acc必有值,不会报错
print(f"Rank {rank} avg loss: {avg_loss}, avg acc: {avg_acc}")
return avg_loss, avg_acc
 

步骤 4:分布式训练额外处理(针对 rank6 无数据的根因)

rank6 节点无数据的本质是分布式采样器分片不均,可以在构建 dataloader 时补充配置,从根源避免空数据:
 
from torch.utils.data.distributed import DistributedSampler

# 构建分布式采样器(图像分类训练的标准配置)
train_sampler = DistributedSampler(
train_dataset, # 你的图像分类训练数据集
num_replicas=world_size, # rank总数(比如8)
rank=rank, # 当前rank
drop_last=True # 核心!丢弃最后一个不完整批次,避免某rank无数据
)

# 构建dataloader
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size, # 批次大小
sampler=train_sampler, # 使用分布式采样器
num_workers=4, # 数据加载线程
pin_memory=True # 加速GPU数据传输
)
 
  • drop_last=True:会丢弃总样本数不能被num_replicas×batch_size整除的最后一个批次,确保每个 rank 的 dataloader 都有数据;
  • 若不想丢弃样本:可以给数据集补充 “dummy 样本”(比如复制最后几个样本),让总样本数能被整除。

三、额外避坑点(CV 图像分类 2.0 训练专属)

  1. 验证集 / 测试集同样要处理:不仅训练循环,验证 / 测试阶段计算 avg 指标时,也要按上述方式初始化 + 条件判断,避免 eval 阶段报同样的错;
  2. 分布式指标同步:若需要跨 rank 同步 avg 指标(比如求所有 rank 的平均损失),空数据的 rank 用 0 值填充,再用torch.distributed.all_reduce同步;
  3. 日志打印要明确:给空数据的 rank 加日志(比如[Rank 6] 无训练数据),方便后续排查数据分片问题。

总结一下下

  1. 核心问题avg变量仅在有数据的分支赋值,rank6 节点无数据导致未赋值,触发UnboundLocalError
  2. 修复关键:先初始化avg为默认值(如 0.0),再通过条件判断避免空列表 / 除以 0;
  3. 分布式兜底:设置DistributedSampler(drop_last=True),避免某 rank 的 dataloader 为空。
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。