Unexpected key(s) in state_dict: "model.electra.embeddings.position_ids". 에러 해결하기
예전에 학습한 모델을 다시 써볼려고 하는데
import pytorch_lightning as pl
import torch
import torch.nn as nn
from transformers import AutoTokenizer,AutoModelForSequenceClassification,BertForSequenceClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
class TextClassificationStudentModule(pl.LightningModule):
def __init__(self, config, labels, lr=5e-4, alpha=1.0):
super().__init__()
self.save_hyperparameters()
if isinstance(config, str):
self.model = AutoModelForSequenceClassification.from_pretrained(
config, num_labels=len(labels)
)
else:
self.model = BertForSequenceClassification(config)
self.multiclass = len(labels) > 1
self.criterion = nn.CrossEntropyLoss() if self.multiclass else nn.BCELoss()
self.soft_label_criterion = nn.BCELoss() # nn.KLDivLoss(reduction='batchmean')
self.labels = labels
def configure_optimizers(self):
opt = optim.Adam(self.parameters(), lr=self.hparams.lr)
return opt
# sched = optim.lr_scheduler.StepLR(opt, 200, 0.5)
# return [opt], [sched]
def forward(self, input_ids, attention_mask=None):
logits = self.model(input_ids, attention_mask=attention_mask).logits
if self.multiclass:
logits = logits.softmax(dim=-1)
else:
logits = logits.sigmoid().squeeze(1).float()
return logits
def _shared_step(self, batch):
ids, masks, labels, soft_labels = batch
alpha = self.hparams.alpha
logits = self(ids, masks)
ce_loss = self.criterion(logits, labels)
kd_loss = self.soft_label_criterion(logits, soft_labels)
loss = alpha * ce_loss + (1 - alpha) * kd_loss
return {
"loss": loss,
"logits": logits,
"labels": labels,
"ce_loss": ce_loss,
"kd_loss": kd_loss,
}
def training_step(self, batch, batch_idx):
return self._shared_step(batch)
def validation_step(self, batch, batch_idx):
return self._shared_step(batch)
def _shared_epoch_end(self, outputs, stage):
outputs = join_step_outputs(outputs)
loss_names = ["loss", "ce_loss", "kd_loss"]
for name in loss_names:
loss = outputs[name].mean()
self.log(f"{stage}_epoch_{name}", loss, prog_bar=True)
logits = outputs["logits"]
labels = outputs["labels"]
acc = tm.accuracy(logits, labels.int())
self.log(f"{stage}_acc", acc, prog_bar=True)
def training_epoch_end(self, outputs):
self._shared_epoch_end(outputs, "train")
def validation_epoch_end(self, outputs):
self._shared_epoch_end(outputs, "val")
tokenizer_curse = AutoTokenizer.from_pretrained("monologg/koelectra-small-v3-discriminator")
hate = "./hate.ckpt"
curse = "./curse.ckpt"
curse = TextClassificationStudentModule.load_from_checkpoint(curse, device)
hate = TextClassificationStudentModule.load_from_checkpoint(hate, device)
def detect_curse(text):
with torch.no_grad():
model_input = tokenizer_curse(text, return_tensors="pt")
curse_pred = curse(
model_input["input_ids"], model_input["attention_mask"]
)[0].item()
hate_pred = hate(model_input["input_ids"], model_input["attention_mask"])[
0
].item()
print(curse_pred)
print(hate_pred)
KeyError: 'pytorch-lightning_version'
이런 에러가 나오는데 pytorch lightning 버전이 예전에 개발했을때 사용한 버전과 달라서 생긴 문제
이전에는 1.5.0에서 사용했는데 지금 깔린건 2.4.0이길래 1.5.0으로 재설치
!pip install pytorch_lightning==1.5.0
근데 다시 실행해보니
RuntimeError: Error(s) in loading state_dict for TextClassificationStudentModule: Unexpected key(s) in state_dict: "model.electra.embeddings.position_ids".
이 경우 모델을 로드하는 함수에 strict = False를 설정하면 로드할 수 있다
curse = TextClassificationStudentModule.load_from_checkpoint(curse,device,strict = False)
hate = TextClassificationStudentModule.load_from_checkpoint(hate, device,strict = False)
이 에러가 난 이유는 역시 이전 학습환경과 현재 transformers나 pytorch 라이브러리 등의 버전이 맞지 않아서이다.
근데 strict = False를 하면 일부 weight가 누락되어 예측 성능이 예상보다 낮아질 수 있음
어떻게 보면 임시방편이다
그래서 근본적으로 해결할려면 재학습해야하는듯?.. 허허
'프로그래밍 > Pytorch' 카테고리의 다른 글
Pytorch model forward에서 에러나는 경우 대처하기(input output model shape print해보기) (0) | 2024.08.25 |
---|---|
Pytorch에서 padding sequence vs. packed sequence 차이 이해하고 구현하기 (0) | 2024.04.19 |
Pytorch에서 learning rate scheduler 사용하는 방법 알기 (0) | 2024.04.17 |
torch.where()로 tensor내 특정 원소의 위치를 찾기 (0) | 2024.04.14 |
Pytorch의 computational graph와 backward()에 대해 이해하기 (0) | 2024.04.13 |