pytorch에서 loss 기본개념 재활

1. loss

 

input data로부터 forward를 통해 계산한 예측된 결과 output

 

input의 정답 label인 target과의 차이가 loss이다.

 

loss는 error, cost로 불리기도 한다.

 

backward 과정에 의해 loss가 update된다.

 

output과 target의 차이를 어떻게 정의할 것인가?

 

문제와 task 목적에 따라 제곱오차, cross entropy 등 여러가지로 정의할 수 있다.

 

loss에 따라 차이는 바뀔 것이고 class마다도 다를 수 있는데

 

loss의 선택에 따라 training중 parameter 업데이트 과정도 달라지므로 신중하게 선택해야한다.

 

input의 forward 계산으로 얻은 label 예측값인 output과 정답 label인 target의 차이가 loss

 

 

2. nn.Module

 

loss class도 nn.Module을 상속받는다…

 

따라서 __init__와 forward를 반드시 정의해야한다

 

nn 패키지에서 보면 종류도 엄청 많다. nn.CrossEntropyLoss, nn.L1loss, nn.KLDivLoss, nn.MSELoss 등등

 

loss 내부에 __init__과 forward가 있는 모습

 

 

3. loss.backward()에 의한 chain

 

trainloader에 의해 데이터를 batch단위로 뽑아오고 model(input)으로 output을 계산해

 

미리 정의한 loss 함수인 criterion(output,label)으로 loss를 계산함

 

criterion = nn.MSELoss()나 criterion= nn.CrossEntropyLoss() 같이 loss 함수의 관용적인 표현?느낌

 

 

training하면서 model의 module이 가지고 있는 parameter들이 업데이트 되기를 바라는 것인데

 

그래야 다른 input을 집어넣으면서 loss가 단축되면서 좋은 학습이 될거임

 

그런데 코드만 보면 모델의 parameter가 업데이트 되는게 맞을까?? 이런 생각이 든다

 

심지어 loss.backward()하면 업데이트가 된다고는 하는데

 

loss랑 model의 parameter가 무슨 상관이길래 업데이트가 된다는 것일까?

 

model과 loss 모두 nn.Module을 상속받아 같은 행동이 가능하다는 것에 답이 있다

 

model에서 input이 들어와 forward를 calling하면 output이 나오고

 

output이 loss로 정의한 criterion(output, label)에서 input으로 다시 들어가면서 loss의 forward가 calling된다

 

그래서 input부터 loss까지 nn.Module의 forward에 의한 하나의 chain이 완성된다.

 

loss도 nn.Module을 받는 class이면서 model의 output을 input으로 받으므로 model의 forward와 연결 될 수 있다

 

 

input부터 model을 넘어 forward에 의해 gradient계산이 진행되는데

 

loss까지 chain이 완성되니까 loss.backward()하나만으로 model의 parameter가 업데이트 될 수 있다

 

 

4. 나만의 loss 만들기

 

output과 target의 차이를 만드는 과정에 특별한 상상을 더하여 변형함

 

특정한 case에는 더욱 penalty를 주고 다른 case에는 penalty를 덜 준다든지 여러가지 변형이 가능

 

loss도 nn.Module을 상속받는 class이므로 dataset이나 model을 만드는 것처럼 커스터마이징으로 만들 수도 있다

 

경험상 잘 설계하면 생각보다 효과적이다.

 

 

5. focal loss

 

class imbalance 문제에 효과적일 수 있음

 

특정 class의 data 수가 너무 적으면 보통은 그것에 대해 모델이 확률을 적게 부여할 수 있음

 

그래서 맞추기 쉬운 data가 많은 class에 대해선 틀렸다고 너무 많은 penalty를 주기보다는 조금만 줘서 언제든지 업데이트할 수 있도록 만든다.

 

반면 확률이 낮은, data가 부족한 class에 대해선 변칙적으로 penalty를 많이 줘서 지금 아니면 안된다는 생각으로 빠르게 학습시키도록 만든다

 

 

6. label smoothing loss

 

이미지는 보통 그 특성상 다양한 feature들이 들어가있어서 100% 특정한 하나의 label이라고 말하기는 어렵다

 

이런 특징을 반영하여 이미지의 label을 one hot vector [0,1]로 표현하기보다는 원소의 합은 1이지만 더욱 soft한 표현 [0.2,0.8]로 표현

 

이러한 label을 사용하여 계산한 loss가 label smoothing loss이다.

 

 

개에 focusing 되어 있어서 개라고 label했는데 100% 개 이미지라고 말할 수 있을까?

 

옆에 고양이가 있어서 100% 개라고 부르기는 조금 아쉽다

 

80% 개 20% 고양이정도?

TAGS.

Comments