pytorch를 이용한 inference process 기본기 완벽하게 이해하기 & pytorch lightning

1. model.eval()

 

model.train()과 비슷하게 model을 evaluation mode로 바꾸는 것

 

 

 

evaluation 전에 반드시 설정하는 것이 좋다

 

batchnorm이나 dropout같은 것들이 training과정과 evalutation과정에서 다르게 동작해야해서 설정해주는 것이 의미 있다

 

실제로 쓰지 않으면.. 결과가 매우 다르기 때문에 잊지말고 사용해야한다.

 

 

2. with torch.no_grad()

 

evaluation은 단지 검증과정이다.

 

중간에 model의 parameter가 update된다면 문제가 있음

 

무슨 말이냐면 training 단계에서 update된 parameter를 가지고 evaluation하고 싶은거지

 

inference과정에서 따로 parameter를 또 update하면 말이 안되는거

 

with torch.no_grad():로 with와 함께 사용하여 감싸진 문장들은 grad_enabled인 parameter들을 전부 False로 바꿔준다

 

torch.no_grad()가 수행되면 __enter__가 호출되어 grad_enabled인 것을 False로 setting함

 

 

forward 과정에서 grad가 발생해도 requires_grad=False로 바뀌니 grad를 굳이 update하지 않는다??? grad 계산도 안할 것 같은데

 

----------------------------------------------------

 

경험상 반드시 사용하는게 좋음

 

with torch.no_grad():를 쓰면 이 구간에서는 gpu 메모리를 사용하지 않기 때문에 gpu가 터지질 않아..

 

안쓰면 터지는 경우가 있더라고 경험상

 

 

3. inference process 기본 base 코드

 

model.eval()이 앞에 있어야하고 with torch.no_grad():를 까먹지 말도록

 

correct = 0
total = 0

net.eval() #change eval model

with torch.no_grad():
    
    for data in testloader:
        
        images, labels = data
        
        outputs = net(images)
        
        _,predicted = torch.max(outputs.data,1)
        
        total += labels.size(0)
        
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct/total))

 

 

4. validation은 어떻게 할까?

 

inference process에서 testloader를 쓰지 않고 validloader를 쓰면 그것이 validation 과정이다.

 

inference는 학습이 끝나고나서, 평가하는 과정이고

 

validation은 train중에 model을 중간 평가하는 과정이므로 model을 학습시키지는 않는다는 점도 중요하긴하다

 

#val loop

with torch.no_grad():
    print("Calculating validation results...")
    
    model.eval()
    
    val_loss_items=[]
    val_acc_items=[]
    
    for val_batch in val_loader:
        
        inputs,labels = val_batch
        
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outs = model(inputs)
        preds = torch.argmax(outs, dim = -1)

 

 

5. model을 저장하는 방법 & 저장한 모델을 불러오는 방법

 

https://tutorials.pytorch.kr/beginner/saving_loading_models.html

 

모델 저장하기 & 불러오기

Author: Matthew Inkawhich, 번역: 박정환,. 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다 읽는 것도 좋은 방법이지만, 필요한 사용 예의 코드만 참고하

tutorials.pytorch.kr

 

 

특정 조건을 만족하면 torch.save()를 이용하여 model을 중간 저장하게 만들 수 있다

 

torch.save(model.state_dict(), "save_path".pth)

 

 

model.state_dict()로 model의 parameter만을 저장할 수 있음

 

보통은(적어도 나는..) .pth로 저장을 함

 

모델을 전부 저장할수도 있기는한데, model.state_dict()로 가중치만을 저장하는것이 보통이다

 

if val_loss < best_val_loss:
    
    print("New best model for val loss! saving the model..")
    torch.save(model.state_dict(), f"results/{name}/{epoch:03}_loss_{val_loss:4.2}.pth")
    best_val_loss = val_loss

if val_acc > best_val_acc:
    
    print("New best model for val accuracy! saving the model..")
    torch.save(model.state_dict(), f"results/{name}/{epoch:03}_accuracy_{val_acc:4.2%}.pth")
    best_val_acc = val_acc

 

이렇게 저장한 모델을 실제로 불러오고 싶다면..?

 

torch.load(<path>)로 pth 파일을 불러오고, model.load_state_dict() 함수에 불러온 pth파일을 덮어씌운다

 

model = ModelClass(*args,**kwargs) #실제 사용한 모델의 class

checkpoint = torch.load(PATH) #저장한 pth파일 load

model.load_state_dict(checkpoint) #불러온 가중치(pth파일)로 덮어씌우기

 

보통 gpu로 training을 하기 때문에, parameter가 gpu 모드로 되어있는 경우가 많다.

 

그런데 model class가 cpu라면 에러날 수 있는데... torch.load에서 map_location = torch.device('cpu') 옵션을 추가해준다

 

model = ModelClass(*args,**kwargs) #실제 사용한 모델의 class

checkpoint = torch.load(PATH, map_location = torch.device('cpu')) #저장한 pth파일 cpu모드로 load

model.load_state_dict(checkpoint) #불러온 가중치(pth파일)로 덮어씌우기

 

6. pytorch lightning

 

pytorch가 커스터마이징이 편하다거나… 그런다고 하더라도 실제 일할때는 생산성이나 재활용성이 되게 중요함

 

이럴때 baseline이 있으면 편한데

 

pytorch가 그런 것을 신경은 어느정도 썼다고 하더라도 그것조차도 귀찮다는 것

 

model의 train부터 inference까지 모든 과정을 하나의 class로 정의해버려서 마치 keras처럼 trainer.fit()하면 한번에 이루어지게 만듦

 

 

오른쪽에 긴 코드가 왼쪽의 3줄로 만들어짐

 

이렇게 생산성이 좋아서 요새 열심히 개발중이면서 주목받는다는데(2년전 이야기..)

 

지금은 어느정도까지 발전했는지는 모르겠다... 쓰긴 쓰는것 같던데

 

생산성 측면이나 편하다는 강점이 있긴하지만 pytorch의 근본적인 부분에서 정확히 이해를 하고 있는것이 중요하다

 

process들이 어떻게 동작하는지 기본을 충분히 이해하는 것이 뭐든 중요하다 이 말이야

 

그래야 자유롭게 응용도 가능하고 변형도 가능하거든

 

 

 

TAGS.

Comments