利用神经网络实现手写字体识别
本文主要是讲如何使用 PyTorch 实现手写数字识别,包括MNIST数据集加载处理、神经网络模型定义、训练并评估模型。
1.Mnist数据集
Mnist数据集是美国国家标准与技术研究院收集的关于手写数字扫描图像及其对应识别数字的数据集。该数据集分为两部分:
第一部分包含60000幅28x28大小的灰度图及对应识别数字,用作训练数据,这些图像扫描自250个人的手写样本。
第二部分包含10000幅28x28大小的灰度图及对应识别数字,用作测试数据,为了保证测试结果,这些图像来自另外一批人。
部分训练数据:
部分训练数据对应识别数字:
[5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9 4 0 9 1]
部分测试数据:部分测试数据对应识别数字:
[3 8 6 9 6 4 5 3 8 4 5 2 3 8 4 8 1 5 0 5 9 7 4 1]
2.Mnist数据集数据处理
1.数据加载
1 | def load_mnist(): |
1 | # 训练数据集 训练结果集 测试数据集 验证结果集 |
2.数据打印
1 | # 训练样本50000 每个样本784个像素点数据 |
3.数据转换
张量(Tensor):PyTorch 中的基础数据结构,类似于矩阵或多维数组,用于表示和存储数据。张量不仅支持数学运算,还能够支持深度学习模型中所需的自动求导功能(通过反向传播计算梯度),还能支持 GPU 加速。
1 | # 将Mnist数据集的数据转化为可供PyTorch使用的张量 |
4.数据打包和批量加载
创建训练和验证数据的 DataLoader 对象
- TensorDataset:将输入数据和标签组合成一个数据集,可以方便地用来处理训练集和验证集
- DataLoader:用来批量加载数据,支持多种功能,比如按批次加载数据、随机打乱数据、并行加载等
1 | def get_tensor_dataset(batch): |
3.模型定义
1 | from torch import nn |
模型的输入的是图像数据,输出是个10分类。模型使用比较经典的基于全连接层的网络结构。
1.全连接层(Fully Connected Layer)
公式:
$$
f_c = w x + b
$$
- $w$:权重矩阵(weights),每个输入特征对应一个输出特征的权重
- $x$:输入向量(input),通常是当前层的输入数据
- $b$:偏置项(bias),是一个额外的参数,添加到线性变换中,以帮助模型更好地拟合数据
- $f_c$:输出向量(output),表示该层经过线性变换后的结果,是通过加权和加上偏置计算出来的。
在PyTorch中全连接层是通过 nn.Linear(in_features, out_features)
这个方法来实现的
- in_features:输入特征的数量,即输入数据的维度
- out_features:输出特征的数量,即该层的神经元个数
nn.Linear会自动处理权重矩阵和偏置项的初始化、更新和应用
2.Dropout层
用于防止过拟合,在每次前向传播中,随机丢弃 50% 的神经元。
3.前向传播(Forward Propagation)
输入数据(如图像、文本等)通过网络中各层的传递和计算,得到最终模型的输出的过程。
前向传播作用:
- 根据输入数据生成模型的预测结果
- 对比预测结果和真实结果,计算出损失,用于反向传播更新权重
4.反向传播(Back Propagation)
通过计算损失函数的梯度来调整神经网络的参数(即权重和偏置),从而最小化模型的损失,下次迭代更准确。
反向传播的主要作用:计算每层参数(权重、偏置)的梯度以及更新模型参数,优化模型性能
在pytorch中反向传播有现成的API,实现起来比较方便。
5.激活函数(Activation Function)
决定了每个神经元的输出,使得神经网络能够学习到输入和输出之间复杂的非线性关系。如果没有激活函数,神经网络无论有多少层,它仍然只是一个线性变换,无法解决复杂的问题,比如图像分类、自然语言处理问题等。
常见的激活函数有:Sigmoid、Tanh、ReLU、Softmax 等
1.ReLU激活函数
$$
ReLU(x) = max(0, x)
$$
特点:
- 计算简单:仅需比较和取最大值操作,计算效率高
- 非线性:虽然形式简单,但能够引入非线性,使神经网络可以学习复杂模式
- 稀疏激活:负值输出为0,可让网络中的部分神经元保持“沉默”,提升模型的稀疏表示能力
优点:
- 缓解梯度消失问题:在正区间梯度恒为1,避免了深层网络因梯度连乘导致的梯度消失,优于Sigmoid/Tanh
- 加速收敛:相比Sigmoid/Tanh,ReLU的梯度更稳定,训练速度通常更快
- 生物学合理性:类似神经元的“全有或全无”激活机制
缺点:
- Dead ReLU问题:如果神经元输出恒为0(如初始化不良或学习率过高),梯度无法更新,导致永久性“死亡”
- 非零中心化:输出均值非零,可能影响梯度下降的效率(但影响通常较小)
2.Sigmoid激活函数
一般用在二分类问题中,通常用于模型的输出层,Sigmoid 函数将输入值(通常是一个实数)映射到一个范围为 (0, 1) 的值,输出类似于概率,可以解释为某个类别的概率。
特点:
- 输出范围(0,1),适合于将输出解释为概率,特别是用于二分类任务
- 平滑且连续,适合用于模型的激活函数,因为它能提供稳定的梯度,帮助优化过程
- 单调递增,随着输入值的增大,输出值也不断增大
- 非线性,使得神经网络能够学习输入和输出之间的复杂关系
- 中心对称性,可以处理二分类问题中的“平衡”预测
- 梯度计算,可以通过函数值本身计算出来,计算上较为简单
优点:
- 概率输出:天然适合二分类问题(如逻辑回归)
- 可微性:梯度计算简单,适用于反向传播
缺点:
- 梯度消失:当输入 ∣x∣ 较大时,梯度接近0,导致深层网络难以训练。 例如:σ(5)≈1,此时梯度 σ′(5)≈0
- 非零中心化:输出均值>0,可能使梯度更新呈“锯齿状”,影响收敛速度
- 计算成本:涉及指数运算,比ReLU慢
4.模型训练、验证
1.损失函数(Loss Function)
模型训练中用于衡量模型预测与真实值之间差距的重要工具,选择什么损失函数取决于具体的任务类型。
回归任务常用的损失函数:
- 均方误差损失(MSE Loss):用于预测房价、温度等连续值,对大的预测误差更为敏感,适合于要求精确度较高的回归问题
- 平均绝对误差损失(MAE Loss):适用于异常值(outliers)不太敏感时
分类任务常用的损失函数:
- 交叉熵损失(Cross-Entropy Loss):模型输出通常经过 Softmax 激活,表示每个类别的概率
- 二分类交叉熵损失(Binary Cross-Entropy Loss):适用于二分类任务(如垃圾邮件分类、图像中的物体检测等),模型输出通常经过 Sigmoid 激活,表示某一类的概率
- 多标签二分类交叉熵损失(Multi-label Binary Cross-Entropy Loss):适用于每个样本有多个标签的情况,例如图像可以同时包含多个物体
生成模型损失函数:
- 对抗性损失(Adversarial Loss):在生成对抗网络(GAN)中,生成器和判别器通过相互对抗来训练
为了衡量模型预测的概率分布与实际标签的概率分布之间的差异,手写字体识别用的损失函数是交叉熵损失。
2.优化器(Optimizer)
在训练时,更新网络中参数(如权重和偏置)的算法。它的作用是通过计算损失函数的梯度,并根据梯度来调整模型的参数,使得模型的预测越来越接近真实值,从而最小化损失函数。
特点:
- 学习率(Learning Rate):决定了每次更新时参数的调整幅度。学习率过大会导致训练不稳定,学习率过小则可能导致收敛速度过慢
- 动量(Momentum):在更新时引入历史梯度的影响,可以帮助优化器跳出局部最小值,快速收敛
- 自适应学习率:根据参数的更新历史调整每个参数的学习率。这可以帮助优化器在不同的训练阶段更好地调整步长
- 梯度裁剪(Gradient Clipping):当梯度过大时,优化器会裁剪梯度值,防止梯度爆炸,保持训练稳定性
- 加速收敛:通过使用动量、Adagrad、Adam 等优化技术,优化器可以更快地收敛,提高训练效率
常用的优化器包括 SGD、Adam、RMSprop、Adagrad 等,它们各有优缺点,可以根据任务的需要选择合适的优化器。其中Adam 是最常用的优化器之一,它结合了动量和自适应学习率,能够加速收敛并减少调整学习率的工作。
手写字体识别用到的两种优化器:
- 随机梯度下降(Stochastic Gradient Descent, SGD):每次更新只使用一个训练样本(或小批量数据),使得参数更新更加频繁,可以更快地收敛
- 相比于传统的梯度下降,计算效率更高
- 可能会导致参数更新方向的噪声,收敛速度较慢,但有时能跳出局部最小值
- 需要设定批量大小(batch size)
- Adam(Adaptive Moment Estimation):结合了动量(Momentum)和自适应学习率(Adagrad),它能够有效地调整学习率,适应不同的梯度方向
- 可以自动调整不同参数的学习率
- 非常适合处理大规模数据和非平稳目标
下面是在两种优化器SGD和Adam下的训练过程:
1 | # optim.SGD(model.parameters(), lr=0.001) |
很明显在手写字体识别模型训练中,使用Adam产生的损失要小于SGD优化器,损失越小代表模型识别的准确率越高。
1 | # 计算损失 |
模型定义以及主方法调用:
1 | model = Mnist_NN() # 定义训练模型 |
5.模型测试
用测试集的数据对训练完的模型去做验证,验证模型的准确率:
1 | correct = 0 |
SGD优化器输出:
测试集10000张图片正确率:87 %
Adam优化器输出:
测试集10000张图片正确率:97 %
6.总结
Mnist数据集手写字体识别案例作为PyTorch库的HelloWorld,讲了在训练神经网络过程中的一些基本概念和PyTorch库的基础使用,包括
- 数据预处理,图像像素值转化为[0, 1](Mnist数据集已经做了);
- 数据转化为PyTorch方便使用的Tensor数据格式;
- 使用DataLoader进行数据加载,方便高效训练;
- 有多个线性层组成并且后面跟随一个激活函数的全连接网络,最后输出0~9十分类;
- 使用交叉熵损失函数计算损失;
- 使用SGD和Adam优化器调整模型参数
7.备注
环境:
- mac: 15.2
- python: 3.12.4
- pytorch: 2.5.1
- matplotlib: 3.8.4
- numpy: 1.26.4
数据集:
https://github.com/keychankc/dl_code_for_blog/tree/main/001_nn_digital_recognition/data/mnist
完整代码:
https://github.com/keychankc/dl_code_for_blog/tree/main/001_nn_digital_recognition