Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
2024.05
Apple Machine Learning Research
1. 요약
대규모 언어 모델(LLM)은 일반적으로 고정 길이의 토큰 시퀀스로 구성된 데이터셋을 사용하여 훈련됩니다.
이러한 데이터셋은 다양한 길이의 문서를 무작위로 연결한 후, 정해진 목표 길이의 시퀀스로 분할하여 생성됩니다.
그러나 이러한 연결 방식은 시퀀스 내에서 문서 간 주의(cross-document attention)가 발생하게 하며, 이는 학습 신호로 적합하지 않을 뿐만 아니라 계산 효율성도 떨어뜨립니다.
또한, 긴 시퀀스에 대한 훈련은 주의 계산의 이차적 비용(quadratic cost) 때문에 계산적으로 부담이 큽니다.
이 연구에서는 이러한 문제를 해결하기 위해 데이터셋 분해(dataset decomposition)라는 새로운 가변 시퀀스 길이 훈련 기법을 소개합니다.
우리는 데이터셋을 고유 문서에서 추출된 동일한 길이의 시퀀스를 포함하는 여러 버킷(bucket)의 합집합으로 분해합니다.
훈련 중에는 모든 버킷에서 커리큘럼을 기반으로 동시에 샘플링하여 가변 시퀀스 길이와 배치 크기를 사용합니다.
Concat-and-chunk 방식의 기존 접근법은 훈련의 매 단계에서 고정된 주의 비용을 발생시키는 반면,
제안된 방법은 각 단계에서 실제 문서 길이에 비례하는 비용만을 발생시켜 훈련 시간을 크게 절약할 수 있습니다.
이를 통해 우리는 기존 접근법으로 2k(2,000) 컨텍스트 길이 모델을 훈련하는 비용으로 8k(8,000) 컨텍스트 길이의 10억 매개변수 모델을 훈련할 수 있었습니다.
웹 스케일 코퍼스(corpus)에서 수행한 실험 결과, 제안된 방법은 표준 언어 평가 및 긴 컨텍스트 벤치마크에서 성능을 크게 향상시키며, 기존 방법 대비 3배 빠르게 목표 정확도에 도달했습니다.
이 방법은 긴 시퀀스에 대한 효율적인 사전 훈련을 가능하게 할 뿐만 아니라 데이터셋 크기와 함께 효과적으로 확장됩니다.
마지막으로, 우리는 대규모 언어 모델 훈련에서 비교적 덜 연구된 중요한 측면, 즉 시퀀스 길이의 분포와 커리큘럼이 성능에 미치는 유의미한 차이에 대해 조명합니다.
---------------------------------------------------------------------------------------------------------------------------------------------------------
2. dataset decomposition
토큰화된 문서 집합 $D = \{d_1, d_2, \dots, d_n\}$에 대해, 데이터셋 분해(DD)의 목표는 $D$를 버킷들의 합집합 $\bigcup_{i} D_i$로 재구성하는 것입니다.
여기서 각 버킷 $D_i$는 다음 조건을 만족합니다
(1) 각 버킷 $D_i$는 길이가 $l_i$인 토큰 시퀀스로 구성된다
(2) 각 시퀀스 $s \in D_i$는 하나의 문서 $d \in D$의 부분 시퀀스이다
(3) $D$의 각 토큰은 정확히 하나의 $D_i$에만 나타난다.
이 분해는 각 시퀀스가 고유한 문서에 속하게 하여 훈련 중 시퀀스 내에서 문서 간 주의(cross-document attention)가 발생하지 않도록 보장합니다.
또한, 주어진 버킷 $D_i$ 내의 모든 시퀀스는 동일한 길이 $l_i$를 가지므로 효율적인 배치 처리가 가능합니다.
위와 같이 정의된 데이터셋 분해는 유일하지 않습니다.
우리는 $l_i = 2^i$로 정의된 특정 분해 방법을 제안하여, 원래의 문서 시퀀스 길이 분포를 최적으로 유지하면서 효율적인 배치 사전 훈련을 가능하게 합니다
우리는 문서 수준에서 분해를 적용하므로, 이 방법을 기존의 데이터 준비 파이프라인(모델 훈련 전에 수행하는 단계)에 쉽게 통합할 수 있으며, 대규모 데이터셋에 확장 가능합니다.
길이가 $l$인 토큰화된 문서 $d \in D$에 대해, $l = 2^{i_1} + 2^{i_2} + \dots + 2^{i_k}$로 나타낼 수 있는 이진 분해를 사용하여, 문서 $d$를 길이가 각각 $2^{i_1}, 2^{i_2}, \dots, 2^{i_k}$인 $k$개의 인접한 시퀀스로 나눕니다.
그런 다음 각 시퀀스 $s_j$ (길이가 $2^{i_j}$)는 버킷 $D_{i_j}$에 할당됩니다.
그림 2는 이 방법의 개념적 표현을 보여줍니다.
제안된 데이터셋 분해 접근법을 통해, 각 버킷 $D_i$는 원본 문서 $d$에서 추출된 시퀀스를 포함하며, 문서 $d$의 길이는 최소한 $2^{i_2}$입니다.
3. Variable sequence length training
우리는 $D_i$가 길이가 $2^i$인 시퀀스를 포함하는 k개의 버킷이 존재한다고 가정합니다.
여기서 $b$는 최적화 단계당 사용되는 토큰의 목표 배치 크기입니다.
가변 시퀀스 길이(VSL) 훈련에서는 매 최적화 단계에서 먼저 사용할 수 있는 선택지 중에서 $i$를 샘플링한 후,
버킷 $D_i$에서 $b/2^i$만큼의 시퀀스를 선택합니다.
$D_i$는 길이가 $2^i$인 시퀀스로 구성되므로, 최적화 단계당 본 시퀀스의 수는 $i$의 선택과 관계없이 항상 $b$로 일정합니다.
VSL 알고리즘을 사용한 LLM 훈련에는 몇 가지 장점이 있습니다.
첫째, 최적화 단계당 본 토큰의 총 수가 변하지 않기 때문에, VSL은 최적화 역학을 변경하지 않으며, 기본 설정에서 사용한 것과 동일한 하이퍼파라미터를 사용할 수 있습니다
둘째, 고정된 $b$(단계당 토큰 수)에 대해 최적화 단계를 완료하는 데 걸리는 시간(순전파+역전파)은 시퀀스 길이에 따라 달라지며, 이는 주의 메커니즘의 제곱 비용에 기인합니다 [59].
VSL 훈련에서는 매 최적화 단계의 비용이 해당 단계에서 샘플링된 버킷 $D_i$에 의존하므로, 더 비싼 단계(긴 시퀀스에 해당)는 더 저렴한 단계(짧은 시퀀스에 해당)로 보완됩니다.
마지막으로, VSL의 샘플링 구성 요소(각 최적화 단계에서 어떤 $D_i$를 선택할지)는 시퀀스 길이에 대한 다른 커리큘럼을 가능하게 합니다. 이러한 커리큘럼이 모델 안정성과 일반화 정확도에 미치는 중요성을 보여줍니다.
4. 평가
우리는 제안한 방법인 데이터 분해를 문서 마스킹(DM), 최적 맞춤 시퀀스 패킹 [17], 문맥 내 사전 훈련(ICLM) [51] 등 다양한 문서 길이를 처리하는 다른 접근 방법들과 비교합니다.
문서 마스킹은 크로스-문서 주의를 방지하여 기본 평가에서 51.5에서 52.4로 향상시킵니다.
그러나 [60]에서는 관련 없는 문서를 이어 붙여서 훈련하는 것이 짧은 시퀀스로만 훈련하는 것보다 긴 문맥 지표를 향상시킬 수 있음을 보여줍니다.
따라서 DM은 긴 문맥 평가에서 약간의 감소를 경험하며, 27.5에서 27.1로 떨어집니다.
최적 맞춤 시퀀스 패킹 [17]은 DM의 문제를 해결하는데, 문서를 더 효율적으로 배열하여 청크화/잘림을 줄이고 각 훈련 시퀀스에서 더 긴 문서 세그먼트를 생성합니다.
이 접근법은 일반적인 평가와 긴 문맥 평가 모두에서 성능을 향상시킵니다.
다른 한편으로, ICLM [51]은 내용 유사도를 기반으로 문서를 정렬하는 방법을 제안합니다.
ICLM이 대규모 Common Crawl 데이터를 사용한 실험에서 일반적인 평가에서는 그 이점이 미미하지만, 유사한 문서들을 이어 붙여 긴 데이터 시퀀스를 구성함으로써 긴 문맥 평가 지표를 27.5에서 28.7로 크게 향상시킵니다.
그러나 ICLM이 제안하는 유사도 찾기 단계는 대규모 데이터에서 자원 소모가 큰 작업입니다.
마지막으로, 우리가 제안한 방법은 관련 없는 콘텐츠에 대한 크로스-문서 주의를 피하고, 일관된 긴 시퀀스를 유지하며, 길이 기반 커리큘럼을 활용하여 성능을 향상시킵니다.
이로 인해 일반적인 평가와 긴 문맥 평가 모두에서 효과적인 성능 개선을 이끌어냅니다.
5. 한계점
기본 방법과 비교한 훈련 속도 향상은 목표 시퀀스 길이가 충분히 길 때만 유의미합니다.
그렇지 않으면 주의 비용이 훈련의 주요 부분을 차지하지 않으므로, 훈련 속도 향상이 기대되지 않습니다.
https://arxiv.org/abs/2405.13226