pytorch를 이용한 inference process 기본기 완벽하게 이해하기 & pytorch lightning
1. model.eval()
model.train()과 비슷하게 model을 evaluation mode로 바꾸는 것
evaluation 전에 반드시 설정하는 것이 좋다
batchnorm이나 dropout같은 것들이 training과정과 evalutation과정에서 다르게 동작해야해서 설정해주는 것이 의미 있다
실제로 쓰지 않으면.. 결과가 매우 다르기 때문에 잊지말고 사용해야한다.
2. with torch.no_grad()
evaluation은 단지 검증과정이다.
중간에 model의 parameter가 update된다면 문제가 있음
무슨 말이냐면 training 단계에서 update된 parameter를 가지고 evaluation하고 싶은거지
inference과정에서 따로 parameter를 또 update하면 말이 안되는거
with torch.no_grad():로 with와 함께 사용하여 감싸진 문장들은 grad_enabled인 parameter들을 전부 False로 바꿔준다
forward 과정에서 grad가 발생해도 requires_grad=False로 바뀌니 grad를 굳이 update하지 않는다??? grad 계산도 안할 것 같은데
----------------------------------------------------
경험상 반드시 사용하는게 좋음
with torch.no_grad():를 쓰면 이 구간에서는 gpu 메모리를 사용하지 않기 때문에 gpu가 터지질 않아..
안쓰면 터지는 경우가 있더라고 경험상
3. inference process 기본 base 코드
model.eval()이 앞에 있어야하고 with torch.no_grad():를 까먹지 말도록
correct = 0
total = 0
net.eval() #change eval model
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_,predicted = torch.max(outputs.data,1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct/total))
4. validation은 어떻게 할까?
inference process에서 testloader를 쓰지 않고 validloader를 쓰면 그것이 validation 과정이다.
inference는 학습이 끝나고나서, 평가하는 과정이고
validation은 train중에 model을 중간 평가하는 과정이므로 model을 학습시키지는 않는다는 점도 중요하긴하다
#val loop
with torch.no_grad():
print("Calculating validation results...")
model.eval()
val_loss_items=[]
val_acc_items=[]
for val_batch in val_loader:
inputs,labels = val_batch
inputs = inputs.to(device)
labels = labels.to(device)
outs = model(inputs)
preds = torch.argmax(outs, dim = -1)
5. model을 저장하는 방법 & 저장한 모델을 불러오는 방법
https://tutorials.pytorch.kr/beginner/saving_loading_models.html
모델 저장하기 & 불러오기
Author: Matthew Inkawhich, 번역: 박정환,. 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다 읽는 것도 좋은 방법이지만, 필요한 사용 예의 코드만 참고하
tutorials.pytorch.kr
특정 조건을 만족하면 torch.save()를 이용하여 model을 중간 저장하게 만들 수 있다
torch.save(model.state_dict(), "save_path".pth)
model.state_dict()로 model의 parameter만을 저장할 수 있음
보통은(적어도 나는..) .pth로 저장을 함
모델을 전부 저장할수도 있기는한데, model.state_dict()로 가중치만을 저장하는것이 보통이다
if val_loss < best_val_loss:
print("New best model for val loss! saving the model..")
torch.save(model.state_dict(), f"results/{name}/{epoch:03}_loss_{val_loss:4.2}.pth")
best_val_loss = val_loss
if val_acc > best_val_acc:
print("New best model for val accuracy! saving the model..")
torch.save(model.state_dict(), f"results/{name}/{epoch:03}_accuracy_{val_acc:4.2%}.pth")
best_val_acc = val_acc
이렇게 저장한 모델을 실제로 불러오고 싶다면..?
torch.load(<path>)로 pth 파일을 불러오고, model.load_state_dict() 함수에 불러온 pth파일을 덮어씌운다
model = ModelClass(*args,**kwargs) #실제 사용한 모델의 class
checkpoint = torch.load(PATH) #저장한 pth파일 load
model.load_state_dict(checkpoint) #불러온 가중치(pth파일)로 덮어씌우기
보통 gpu로 training을 하기 때문에, parameter가 gpu 모드로 되어있는 경우가 많다.
그런데 model class가 cpu라면 에러날 수 있는데... torch.load에서 map_location = torch.device('cpu') 옵션을 추가해준다
model = ModelClass(*args,**kwargs) #실제 사용한 모델의 class
checkpoint = torch.load(PATH, map_location = torch.device('cpu')) #저장한 pth파일 cpu모드로 load
model.load_state_dict(checkpoint) #불러온 가중치(pth파일)로 덮어씌우기
6. pytorch lightning
pytorch가 커스터마이징이 편하다거나… 그런다고 하더라도 실제 일할때는 생산성이나 재활용성이 되게 중요함
이럴때 baseline이 있으면 편한데
pytorch가 그런 것을 신경은 어느정도 썼다고 하더라도 그것조차도 귀찮다는 것
model의 train부터 inference까지 모든 과정을 하나의 class로 정의해버려서 마치 keras처럼 trainer.fit()하면 한번에 이루어지게 만듦
이렇게 생산성이 좋아서 요새 열심히 개발중이면서 주목받는다는데(2년전 이야기..)
지금은 어느정도까지 발전했는지는 모르겠다... 쓰긴 쓰는것 같던데
생산성 측면이나 편하다는 강점이 있긴하지만 pytorch의 근본적인 부분에서 정확히 이해를 하고 있는것이 중요하다
process들이 어떻게 동작하는지 기본을 충분히 이해하는 것이 뭐든 중요하다 이 말이야
그래야 자유롭게 응용도 가능하고 변형도 가능하거든
'프로그래밍 > Pytorch' 카테고리의 다른 글
pytorch에서 optimizer & metric 기본개념 재활 (0) | 2023.04.26 |
---|---|
pytorch에서 loss 기본개념 재활 (0) | 2023.04.26 |
pytorch를 이용한 training process 기본기 완벽하게 이해하기 (0) | 2023.04.24 |
개발자가 숫자를 0부터 세야하는 이유 - loss function 사용할때 class는 0부터 시작하기(cuda error) (0) | 2023.04.21 |
pytorch tensor 다루기 재활치료 2편 - view, squeeze, unsqueeze, type, cat, stack, ones_like, zeros_like, inplace - (0) | 2023.03.11 |