目录
老饼讲解:一步一步上手深度学习

【例子】训练一个MLP--数值预测

作者 : 老饼 发表日期 : 2025-09-29 20:31:48 更新日期 : 2026-05-20 12:26:12
老饼讲解-简单易懂,干货满满,爽过嗦螺!


好了,都说MLP好用,那这节我们用就用pytorch来训练一个MLP神经网络吧!

用它来解决什么问题好呢?都说它能拟合任意关系,那么我们不妨试试sin函数吧,看看它能不能拟合sin函数。

一、MLP神经网络代码例子-曲线拟合

好了,闲话不多说,直接上代码,开车喽!

如下所示,在sin函数[-5,5]之间采集20个数据,我们需要训练一个三层MLP神经网络,来拟合sin函数x与y的关系。

sin函数的采样数据

 由于样本较为简单,我们不妨将三层MLP神经网络设为4个隐节点,并使用均方差作为损失函数。这里简单地用梯度下降算法它进行训练,具体代码实现如下:

# 本代码用于展示:训练一个MLP用于sin函数拟合
# 本代码来自《老饼讲解-深度学习》www.bbblearn.com
import torch
import matplotlib.pyplot as plt 
torch.manual_seed(99)
# -----计算网络输出:前馈式计算------
def forward(w1,b1,w2,b2,x):                                   
    return w2@torch.tanh(w1@x+b1)+b2

# -----计算损失函数: 使用均方差------
def loss(y,py):
    return ((y-py)**2).mean()

# -----训练数据----------------------
x = torch.linspace(-5,5,20).reshape(1,20)                      # 在[-5,5]之间生成20个数作为x
y = torch.sin(x)                                               # 模型的输出值y

#-----------训练模型-----------------
in_num  = x.shape[0]                                            # 输入个数
out_num = y.shape[0]                                            # 输出个数
hn  = 4                                                         # 隐节点个数
w1  = torch.randn([hn,in_num],requires_grad=True)               # 初始化输入层到隐层的权重w1
b1  = torch.randn([hn,1],requires_grad=True)                    # 初始化隐层的阈值b1
w2  = torch.randn([out_num,hn],requires_grad=True)              # 初始化隐层到输出层的权重w2
b2  = torch.randn([out_num,1],requires_grad=True)               # 初始化输出层的阈值b2

lr = 0.01                                                       # 学习率
for i in range(10000):                                          # 训练10000步
    py = forward(w1,b1,w2,b2,x)                                 # 计算网络的输出
    L  = loss(y,py)                                             # 计算损失函数
    print('第',str(i),'轮,mse:',L.item())                      # 打印当前损失函数值
    L.backward()                                                # 用损失函数更新模型参数的梯度
    w1.data=w1.data-lr*w1.grad                                  # 更新模型系数w1
    b1.data=b1.data-lr*b1.grad                                  # 更新模型系数b1
    w2.data=w2.data-lr*w2.grad                                  # 更新模型系数w2
    b2.data=b2.data-lr*b2.grad                                  # 更新模型系数b2
    w1.grad.zero_()                                             # 清空w1梯度,以便下次backward
    b1.grad.zero_()                                             # 清空b1梯度,以便下次backward
    w2.grad.zero_()                                             # 清空w2梯度,以便下次backward
    b2.grad.zero_()                                             # 清空b2梯度,以便下次backward
    if(L.item()<0.005):                                         # 如果误差达到要求
        break                                                   # 退出训练
px = torch.linspace(-5,5,100).reshape(1,100)                    # 测试数据,用于绘制网络的拟合曲线    
py = forward(w1,b1,w2,b2,px).detach().numpy()                   # 网络的预测值
plt.scatter(x, y)                                               # 绘制样本
plt.plot(px[0,:],py[0,:])                                       # 绘制拟合曲线  
plt.show()                                                      # 展示画布
print("\n模型参数:")                                            # 打印模型参数
print('w1:',w1)                                                 # 打印w1
print('b1:',b1)                                                 # 打印b1
print('w2:',w2)                                                 # 打印w2
print('b2:',b2)                                                 # 打印b2

运行结果如下:

训练好的权重阈值

拟合效果

可以看到,MLP对sin函数的关系拟合得挺好的~说明它成功学习到x和y的关系了~

二、代码解说

经过前面的学习,本文的代码没什么难度,下面简单解说一下~

1.1.数据与参数初始化

好了,直接从14行开始看代码片段-1

其中,15-16是生成sin函数的采样数据,用于训练,19-25是变量的赋值与参数的初始化,以前已经讲过很多次了,没什么好说的。

1.2.模型训练代码

好了,下面是模型训练的代码代码片段-2

其实也没什么好说的,就是用梯度下降来训练MLP的参数。

第29行先计算模型的输出,其中forward是第7行定义的模型的输出计算函数,在计算出输出后,把它放到第11行定义的loss函数中计算损失值,最后第32行对损失值进行backward,从而更新参数的梯度。

然后,33-40行,对参数按负梯度进行更新,并清空梯度。最后41-42行,检查误差是否足够小,如果已经很小,就退出训练。 

1.3. 模型训练结果打印

最后就是打印模型的训练效果了,包括模型的拟合曲线和模型参数代码片段-3

第43-44行,计算模型在之间的预测值。

第45-47行,画出模型的拟合曲线。

第48-52行,打印模型的参数,w1,b1,w2,b2等等。

总结

好了,这节其实在代码上没什么好说的,只是展示一个MLP用于数值预测时的具体实现方法,以及它的效果。进一步更具体的感受"MLP拟合任意关系"这一特色。



图标 评论
添加评论