딥러닝 모델에서 가중치를 초기화하는 방법(weight initialization)
신경망을 학습할 때 가중치를 초기화하고 update해야하는데 어떻게 초기화해야 학습에 좋을까?
단순히 0으로 시작해버린다면 gradient가 0으로 계산되는 경우가 많을 것.
너무 큰 값으로 시작한다면, 계산된 activation이 너무 커질 것.
단순한 예시로 위와 같은 신경망에서, bias = 0이라고 가정한다면...
z = W1(W2(W3...(Wn(X))..))로 계산되는데, n개의 weight들의 곱에 input X의 곱으로 중간 output이 계산된다.
만약 weight들의 원소가 1보다 작은 값들이라면.. weight들을 곱할수록 0에 가까워진다.
하지만, 1보다 조금이라도 크다면 weight들을 많이 곱할수록 매우 커진다.
그러다보니 weight들을 처음에 어떤 값들로 시작하는게 딥러닝 학습에 좋을지 연구자들이 많이 연구하였다.
가장 많이 알려진 방법으로 Xavier initialization, He initialization
1) Xavier initialization
https://proceedings.mlr.press/v9/glorot10a.html
"understanding the difficulty of training deep feedforward neural networks'에서 제시하였다.
activation이 linear이고 weight가 independently, input feature의 분산이 동일하다는 가정하에
weight의 좋은 분산을 수학적으로 유도하였다.
$n_{in}$이 layer의 input의 개수이고 $n_{out}$이 layer의 output의 개수라고 하자.
이 때 가중치 $W$를 다음과 같은 normal이나 uniform에서 랜덤으로 생성하여 초기화하는 방법이다.
$$W\sim N(0,\frac{2}{n_{in}+n_{out}})$$
$$W\sim U(-\sqrt{\frac{6}{n_{in}+n_{out}}},\sqrt{\frac{6}{n_{in}+n_{out}}})$$
sigmoid나 tanh같이 x가 0 근처일때, 선형에 가까운 함수인 경우 효과적이다.
2) He initialization
https://arxiv.org/abs/1502.01852
'Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"에서
Xavier initialization이 sigmoid나 tanh에서 효과적이지만, ReLU 계열의 activation에서는 효과적이지 않다고 지적했다.
그리고 나름대로 ReLU계열에서 효과적인 initialization을 유도했는데 그 결과가 다음과 같다.
$n_{in}$이 layer의 input의 개수라고 한다면
가중치 $W$를 다음과 같은 normal이나 uniform에서 랜덤으로 생성하여 초기화하는 방법이다.
$$W\sim N(0,\frac{2}{n_{in}})$$
$$W\sim U(-\sqrt{\frac{6}{n_{in}}},\sqrt{\frac{6}{n_{in}}})$$
xavier의 변형으로 layer의 output 개수를 제거했다
일반적으로 He방법이 조금 더 빠르다고 알려져있다
ReLU일 때 특히 좋다고 한다
3) weight initialization은 왜 중요할까?
GPT-2는 왜 weight initialization으로 layer의 깊이에 따른 차별적인 scaling을 했을까?
layer가 엄청 깊게 쌓여져있는데 딥러닝의 모든 계산은 layer의 matrix multiplication으로 이루어진다.
element들의 분산이 점점 달라지는 현상이 깊을수록 누적되고 결국에는 exploding하거나 vanishing하는 현상이 나타난다.
그래서 weight initialization을 layer의 깊이에 따라 차별적으로 scaling하여
위층에서 weight의 영향력을 줄이고 exploding/vanishing현상을 안정화시키고자 했다.
"딥러닝의 forward 계산으로 output이 exploding되거나 vanishing되는 현상을 막고자 한다"
4) 일반적인 관례
보통 activation으로 ReLU 계열을 사용하다보니 He initialization이 자주 사용된다.
tanh의 경우 Xavier initialization이 효과적일 수 있다.
보통 normal distribution에서 초기화한다.
일반적으로 bias는 0으로 초기화하는 것이 좋다고 한다.
5) 코드 구현 예시
torch.nn.init에 다양한 가중치 초기화 함수를 제공함
https://pytorch.org/docs/stable/nn.init.html
다음과 같이 MNIST data를 준비하고
#library
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random
#download mnist
mnist_train = dsets.MNIST(root = 'MNIST_data/', train = True,
transform = transforms.ToTensor(),
download = True)
mnist_test = dsets.MNIST(root = 'MNIST_data/', train = False,
transform = transforms.ToTensor(),
download = True)
#dataloader
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size = 32,
shuffle = True, drop_last = True)
#왜 test를 shuffle을 true로 했었을까
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 32,
shuffle = True, drop_last = True)
#hyperparameter
device = 'cuda' if torch.cuda.is_available() else 'cpu'
random.seed(777)
torch.manual_seed(777)
if device == 'cuda':
torch.cuda.manual_seed_all(777)
training_epochs = 15
batch_size = 100
torch.nn.init.normal_(linear.weight)은 linear의 weight를 normal 분포에서 랜덤하게 뽑아 채워준다
#https://pytorch.org/docs/stable/nn.init.html
#initialization weight tensor
linear = torch.nn.Linear(784,10,bias=True).to(device)
torch.nn.init.normal_(linear.weight)
total_batch = len(data_loader)
for epoch in range(training_epochs):
avg_cost = 0
for x,y in data_loader:
x = x.view(-1,784).to(device)
y = y.to(device)
optimizer.zero_grad()
hypothesis = linear(x)
cost = criterion(hypothesis, y)
cost.backward()
optimizer.step()
avg_cost += cost/total_batch
print('epoch: ', '%04d' %(epoch+1), 'cost = ', '{:.9f}'.format(avg_cost))
print('learning finished')
epoch: 0001 cost = 1.096136212
epoch: 0002 cost = 1.005576730
epoch: 0003 cost = 0.999088705
epoch: 0004 cost = 1.072223544
epoch: 0005 cost = 1.018894672
epoch: 0006 cost = 1.011124969
epoch: 0007 cost = 1.013777733
epoch: 0008 cost = 1.012180567
epoch: 0009 cost = 1.022216916
epoch: 0010 cost = 1.027872682
epoch: 0011 cost = 1.039777279
epoch: 0012 cost = 1.056721330
epoch: 0013 cost = 1.044611216
epoch: 0014 cost = 1.011985064
epoch: 0015 cost = 1.035102725
learning finished
그리고 다음과 같이 테스트해보면..
with torch.no_grad():
x_test = mnist_test.test_data.view(-1,784).float().to(device)
y_test = mnist_test.test_labels.to(device)
prediction = linear(x_test)
correct_prediction = torch.argmax(prediction,1) == y_test
accuracy = correct_prediction.float().mean()
print('accuracy: ', accuracy.item())
r = random.randint(0,len(mnist_test)-1)
x_single_data = mnist_test.test_data[r:r+1].view(-1,784).float().to(device)
y_single_data = mnist_test.test_labels[r:r+1].to(device)
print('label:',y_single_data.item())
single_prediction = linear(x_single_data)
print('prediction: ', torch.argmax(single_prediction, 1).item())
accuracy: 0.7666000127792358
label: 9
prediction: 9
몇번 써봤는데... 엄청나게 성능을 끌어올린다? 그 정도는 아닌듯
https://sonstory.tistory.com/71#Xavier%20%EC%B4%88%EA%B8%B0%ED%99%94-1