+18888889999
诚信为本:市场永远在变,诚信永远不变。

一起来学PyTorch——torch.nn优化器optim_1

你的位置: 首页 > 门徒平台资讯

一起来学PyTorch——torch.nn优化器optim_1

2024-03-12 11:44:26

\\quad 在学习的过程中,大家可能会感觉很蒙,其实现在做的就是看懂每一个小内容,等之后做一个具体的项目,就可以串联起来了。 优化器用于优化模型,加速收敛。


1.SGD方法
\\quad 随机梯度下降法,是指沿着梯度下降的方向求解极小值,一般可用于求解最小二乘问题。
\\quad \\quad \\quad \\quad \\quad \\quad \\quad \\quad \\quad \\quad\	heta_{t}=-\\eta * g_{t}

\\quad 其中 g_{t} 代表了参数的梯度, \\eta 代表了学习率, \	heta_{t} 代表了参数更新的幅度。
优点:
\\quad 每次采用的数据量小,会有更多次梯度更新。
缺点:
\\quad 一开始的学习率不好确定,容易出现局部最优解。

2.Adama方法
\\quad 利用梯度的一阶矩和二阶矩动态的估计调整每一个参数的学习率。

3.三层感知机例子-介绍基本的优化过程

from torch import nn
# 先定义一个三层感知机,激活函数使用Relu(小于0的,都转换为0)
class MLP(nn.Module):
    def __init__(self, in_dim, hid_dim1, hid_dim2, out_dim):
        super(MLP, self).__init__()
        #使用Sequential快速搭建三层感知机
        self.layer = nn.Sequential(
            # 第一层
            nn.Linear(in_dim, hid_dim1),
            nn.Relu(),
            # 第二层
            nn.Linear(hid_dim1, hid_dim2),
            nn.Relu(),
            # 第三层
            nn.Linear(hid_dim2, out_dim),
            nn.Relu()
            )
    def forward(self, x):
        y = self.layer(x)
        return y

# 进行模型的实例化
from torch import optim
from torch import nn
# 输入数据为28*28
# 隐藏层中:第一层需输入300个数据,第二层需输入200个数据
# 最后输出10个特征
model = MLP(28*28, 300, 200, 10)
# modeld的结构
#MLP(
#  (layer):Sequention(
#       (0):Linear(in_features=784, out_features=300, bias=True)   
#       (1):Relu()
#       (2):Linear(in_features=300, out_features=200, bias=True)
#       (3):Relu()
#       (4):Linear(in_features=200, out_features=10, bias=True)
#       (5):Relu()
#  )
#)
# 采用SGD优化器, 学习率设为0.01
optimizer = optim.SGD(params = model.parameters(), lr=0.01)
# 设置输出数据
data = torch.randn(10, 28*28)
# 输入模型后,输出数据
output = model(data)
# 因为最终输出10个特征,所以先设置10个特征
label = torch.Tensor([1, 0, 4, 7, 9, 3, 4, 5, 3, 2]).long()
# 求损失
criterion = nn.CrossEntropyLoss()
loss = criterion(output, label)
    # >>> loss
    # tensor(2.2762)
# 清空梯度,在每次优化前都要进行此操作
optimizer.zero_grad()
# 损失的反向传播
loss.backward()
# 利用优化器进行梯度更新
optimizer.step()

\\quad 大家只需要理解这些函数的功能就可以,之后我们会在详细的项目中,把这些功能应用起来。欢迎大家讨论!共同学习!

\\quad

地址:海南省海口市玉沙路58号  电话:0898-66889888  手机:18888889999
Copyright © 2012-2018 门徒-门徒娱乐-注册登录站 版权所有 ICP备案编:琼ICP备88889999号 

平台注册入口