🗒️P4:Pytorch实战:猴痘病识别
00 分钟
2023-5-19
2023-5-20
type
status
date
slug
summary
tags
category
icon
password

1 引言

本文主要实现了构建 CNN 网络实现对猴痘病图片的识别。

1.1 训练营要求

  • 保存训练过程中效果最好的模型参数
  • 加载最佳模型参数识别本地的一张照片
  • 调整网络结构使得测试集 accuracy 正确率达到 88%

1.2 训练记录

  1. 使用原文章代码进行训练
  1. 使用数据增强技术
  1. 修改网络结构
  1. 修改学习率
各个训练记录的结果与改进效果将在文章的结果分析部分详细讲解。
结果:最终 val_accuracy 达到 90%。
附:由于许多内容在其他文章中都有所涉及,因此在此片笔记中不会过多讲解,但是会附上链接

2 前期工作

2.1 设置 GPU

2.2 导入数据

2.2.1 从本地读取数据

输出结果为:['Others', 'Monkeypox']

2.2.2 数据增强

2.3 标签映射

在DatasetFolder中,class_to_idx是一个字典,将类别名映射到类别标签(从0开始),其中类别名是文件夹的名称,类别标签是与之相关联的数字。
为什么要做标签映射呢?
  • 将类别名映射到类别标签是因为在训练深度学习模型时,通常使用类别标签来表示每个样本的类别。
  • 在训练模型时,输入数据被转换为张量,并且每个张量的标签是一个数字,表示与之相关联的类别。
  • 类别标签使得模型可以根据真实标签和预测标签之间的误差来更新模型权重,从而使模型学习到如何将输入数据映射到正确的输出标签。
输出结果为:{'Monkeypox': 0, 'Others': 1}

2.4 划分数据集

3 构建 CNN 网络

对于一般的CNN网络来说,都是由特征提取网络和分类网络构成,其中特征提取网络用于提取图片的特征,分类网络用于将图片进行分类。
关于函数介绍这一部分的内容其实在此前博客中已经有所讲解,但是因其重要性,故不在讲解。
notion image
由于此项目只是个简单的二分类问题,且图片的特征也较少,所以只用这个简单的网络结构就已经得到了不错的训练结果。

4 训练模型

4.1 设置超参数

4.2 编写训练、测试函数

在深度学习中,训练和测试是两个不同的阶段,因此训练函数和测试函数也有不同的实现方式。
  • 训练函数的主要目的是通过反向传播算法来更新模型的参数,以最小化损失函数。在训练函数中,模型参数被反复更新,以逐渐优化模型在训练数据上的性能。因此,在训练函数中,我们通常会使用优化器来更新模型的参数,并对模型在训练数据上的表现进行评估。此外,训练函数还需要计算训练集上的准确率和损失,以便我们可以了解模型在训练数据上的表现。
  • 测试函数的主要目的是评估模型在新数据上的性能。在测试函数中,我们不会更新模型的参数,而是根据测试数据来评估模型的泛化能力。因此,在测试函数中,我们通常不需要使用优化器来更新模型的参数。此外,在测试函数中,我们需要计算测试集上的准确率和损失,以便我们可以了解模型在新数据上的表现。
因此,训练函数和测试函数的区别在于它们的目的和实现方式。训练函数的主要目的是优化模型在训练数据上的性能,并且在训练过程中更新模型的参数,以逐渐优化模型的性能。而测试函数的主要目的是评估模型在新数据上的性能,因此不会更新模型的参数。此外,训练函数和测试函数在计算准确率和损失时,需要使用不同的数据集(训练集和测试集)来进行评估。

4.4 正式训练

训练过程截图如下:
notion image

5 结果可视化

notion image
从上图的训练结果来看,似乎还没有收敛,继续训练 val_accuracy 的正确率似乎还能够继续提高。

5.1 指定图片预测

6 保存并加载模型

将模型的状态字典保存和加载具有几个重要的用途:
  • 训练时的模型检查点: 在深度学习模型训练期间,通常会定期保存模型的状态字典(例如,每个 epoch)以跟踪模型的进度。如果训练过程中断(例如,由于系统崩溃或停电),则可以从最后保存的检查点继续训练,而不是从头开始。这有助于节省时间和计算资源。
  • 与他人共享训练好的模型: 一旦深度学习模型已经训练好,可以保存其状态字典并与他人共享,以便他们可以用于推断或进一步微调。这对于想要将他们的模型与更广泛的社区分享的研究人员或想要向其客户分发预训练模型的公司尤其有用。
  • 部署训练好的模型到生产环境: 将训练好的深度学习模型部署到生产环境时,可以保存其状态字典并将其加载到生产服务器上。这使您无需在生产服务器上重新训练模型,从而节省时间和计算资源。
  • 实验和调试: 在尝试不同的模型架构或超参数时,可以保存每个训练模型的状态字典,并稍后加载进行进一步分析。这使您可以轻松快速地比较不同模型或超参数的性能。
代码讲解:
torch.save(model.state_dict(), PATH) # 将模型的状态字典保存到指定的文件中。
  • 在这段代码中,你首先在PATH变量中定义了保存模型的状态字典的文件名和路径。然后,你使用 torch.save() 函数将模型的状态字典保存到指定文件中。model.state_dict()方法返回一个包含模型的参数和缓冲区的字典,然后将其保存到PATH指定的文件中。
model.load_state_dict(torch.load(PATH, map_location=device)) # 从指定的文件中加载模型的状态字典
  • 在这段代码中,你使用 torch.load() 函数从 PATH 指定的文件中加载模型的状态字典。然后 model.load_state_dict() 方法将状态字典加载到模型中。Themap_location 参数用来指定模型应该被加载到哪个设备上。如果没有指定 map_location 参数,torch.load() 函数将尝试把模型加载到它保存的同一设备上。如果保存的模型是在GPU上训练的,但你想把它加载到 CPU 上,你可以设置 map_location='cpu'。
在执行 model.load_state_dict() 方法后,模型的参数和缓冲区会从保存的状态字典中加载,模型就可以用于推理或进一步训练了。
 
 

评论
  • Twikoo
  • Cusdis