Pytorch에서 learning rate scheduler 사용하는 방법 알기
1. 개요
learning rate는 model train 성능을 결정하는 중요한 요소
동일한 learning rate를 사용하여 처음부터 끝까지 학습을 할 수도 있지만,
초반에는 큰 learning rate를 사용하여 빠르게 최적값에 가다가, 후반에는 작은 learning rate를 사용하여 미세조정을 할 수도 있다.
기본적인 원리는 지정한 epoch 스텝마다 learning rate에 gamma를 곱한 값을 새로운 learning rate라 하고 다음 epoch을 돌린다.
2. 기본적인 사용법
Pytorch에서는 다양한 learning rate scheduler를 지원하고 있다.
기본적으로 학습시에 batch마다 optimizer.step()을 하고 나서,
batch마다 learning rate를 바꾸고 싶다면 optimizer.step() 다음에 scheduler.step()
epoch마다 learning rate를 바꾸고 싶다면, 1epoch가 끝나고 나서 scheduler.step()
원하는 타이밍에 사용가능
optimizer.step() 다음에 scheduler.step()을 사용하는 것이 좋다.
scheduler.step() 다음에 optimizer.step()을 하면 warning이 남
import torch
import torch.nn as nn
import torch.optim as optim
# loss
loss = nn.MSELoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#scheduler
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
lr_lambda=lambda epoch: 0.95 ** epoch,
last_epoch=-1,
verbose=False)
epochs=100
for epoch in range(epochs):
for i, (data) in enumerate(data_loader):
x_data, y_data = data
optimizer.zero_grad()
estimated_y = model(x_data)
loss = loss(y_data, estimated_y)
loss.backward()
optimizer.step()
#scheduler.step() #batch마다 learning rate update
scheduler.step() # epoch마다 learning rate update
3. 예시
torch.optim.lr_scheduler에서 여러가지 지원하고 있다
https://pytorch.org/docs/stable/optim.html#module-torch.optim.lr_scheduler
각각 서로 다른 특징으로 learning rate를 변화시킨다
단순히 비율만큼 곱해서 계속 감소시키거나, 감소시켰다가 증가시키거나...
보통 자주 쓰는 건?
CosineAnnealingLR, MultiStepLR, StepLR, CyclicLR 이런걸 쓰는 듯?
이론상으로는 성능향상이 없을때 learning rate를 바꾸는 ReduceLROnPlateau가 좋아보이긴함
필요할때 각 scheduler의 parameter를 찾아보면서 사용하면 된다
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
https://sanghyu.tistory.com/113
'프로그래밍 > Pytorch' 카테고리의 다른 글
Pytorch model forward에서 에러나는 경우 대처하기(input output model shape print해보기) (0) | 2024.08.25 |
---|---|
Pytorch에서 padding sequence vs. packed sequence 차이 이해하고 구현하기 (0) | 2024.04.19 |
torch.where()로 tensor내 특정 원소의 위치를 찾기 (0) | 2024.04.14 |
Pytorch의 computational graph와 backward()에 대해 이해하기 (0) | 2024.04.13 |
Pytorch에서 두 tensor가 서로 같은지 비교하고 싶다면? (0) | 2024.04.11 |