pytorch - 모델의 parameter 제대로 이해하기 재활치료
1. model이 가지는 parameter 확인하기
model에 정의된 modules가, 가지고 있는 forward 계산에 쓰일 parameter tensor가 저장되어 있음
.state_dict(), .parameters() 함수를 이용하여 저장된 parameter를 볼 수 있음
.state_dict()는 무엇이 무엇의 parameter인지 확인 가능
.parameters()는 그냥 parameter를 출력해서 뭐가 뭔지 확인은 어렵다
parameter는 weight와 bias로 이루어져있다는 것을 알 수 있다
2. parameter tensor
parameter는 tensor 기반의 class
그냥 tensor가 있고, grad를 가질 수 있는 parameter tensor라는 것이 있는거임.. 이거 되게 중요했어..
-----------------------------------------------------------------------------------------------------------------------------------------
layer는 보통 기본 속성으로 parameter인 layer.weight를 가진다.
parameter tensor와 그냥 tensor은 다르다.
layer의 parameter인 weight에 tensor를 넘겨줄려면 반드시 torch.nn.Parameter() 함수를 사용하여 parameter tensor로 만들고 넘겨줘야한다.
추가적으로 앞으로도 계속 언급할거지만 진짜 고수라면 tensor 조작할때는 항상 shape에 신경쓰고 있어야한다
reshape를 사용하여 shape를 변경하고 copy를 했다
torch.nn.Parameter()함수로 parameter tensor로 만들어서 weight에 넘겨줬다.
만약 parameter tensor로 변환안하고 넘기면 위와 같은 에러가 나타남
-----------------------------------------------------------------------------------------------------------------------------------------
data, grad, requires_grad라는 중요한 변수를 가짐
data는 parameter 값을 가질 것이고 grad는 backward에 의한 미분값
-----------------------------------------------------------------------------------------------------------------------------------------
layer.weight.data는 그냥 tensor이고
layer.weight는 parameter tensor
-----------------------------------------------------------------------------------------------------------------------------------------
requires_grad는 True나 False를 가지는데
requires_grad=True이면 training하면서 backward에 의한 미분값을 계산하면서 parameter를 업데이트함
반면 requires_grad=False이면 해당 parameter는 gradient를 계산하지 않는다.
requires_grad로 gradient 계산을 조정할 수 있다는것이 무슨 의미를 가지냐?
해당 layer를 freeze하여 다른 layer만 학습시키도록 하는 transfer learning 같은데서 활용할 수 있다
nn.Modules 하나만으로 grad,requires_grad까지 다가가서 parameter 학습에 이용된다는 사실
nn.Modules를 상속받은 layer마다 parameter가 존재해서 grad가 있으니 training하면서 parameter를 업데이트함
3. pythonic하다?
이런 과정들이 되게 어려운 로직이 아니라 사실 흔한 python의 일부임
state_dict()로 불러온 parameter 저장소는 사실 python의 dictionary
parameters()같은 경우는 generator
dictionary나 generator만 이해하고 있으면 쉽게 parameter들에 접근하여 변경하고 다양하게 응용이 가능하다
예를 들어 freeze하는거 가능할까???
각각의 기능에 의문을 가지면서 직접 공부하여 파생하는 상상력으로 나중에 다른 문제에 부딪히면 직접 멋지게 구현할 수 있을것
conv1.weight같은 경우 conv1은 고정된 약속이 아니고 그냥 내가 아무렇게나 지은 것
이런 자유도 때문에 에러가 쉽게 발생할 수도 있지만 그만큼 핸들링도 쉬움
state_dict()가 dictionary 구조를 가진다는 것을 안다면 두 모델 A,B가 비슷하다고 할 때 A,B의 parameter를 덮어씌우는 것이 아주 쉬움
그래서 load_state_dict()로 model의 weight를 덮어씌웠지..
'프로그래밍 > Pytorch' 카테고리의 다른 글
pytorch의 tensor를 plt.imshow()했더니 TypeError: Invalid shape for image data (0) | 2023.11.07 |
---|---|
pytorch - flatten과 averaging pooling, training 방법 기본기, layer 구성법 (0) | 2023.05.08 |
pytorch - model, nn.module 제대로 이해하기 재활치료 (0) | 2023.05.01 |
data augmentation & data generation 기본 개념 재활하기 (0) | 2023.04.28 |
pytorch에서 data augmentation은 어떻게 이해해야하는가 (0) | 2023.04.27 |