[Paper Review] InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning
Paper : https://arxiv.org/abs/2303.04947
GitHub : https://github.com/NUS-HPC-AI-Lab/InfoBatch/tree/master
InfoBatch는 Data pruning 프레임워크 중 하나로 lossless(무손실) 성능으로 모델을 가속화하여 학습시키는 방법이다.
📚 사전 지식
이 논문을 효과적으로 이해하는 데 필요한 사전 지식을 정리했다. 이는 각자 가진 도메인 지식의 차이를 고려하여, 공통된 이해 기반을 마련하기 위함이다.
🔹 Data Pruning
우선 Pruning이란 가지치기라고 말하기도 하며 모델에서 중요도 낮은 파라미터를 제거하는 작업을 의미한다. Data Pruning은 학습에 사용되는 데이터의 크기를 줄여 일부만 학습에 참여시키는 것이다. (어떤 데이터를 줄이고 언제 줄일지 등 여러 연구가 있다.) Data Pruning에서 목표는 원본 데이터로 학습한 모델과 Data Pruning을 진행한 데이터로 학습한 모델의 성능이 비슷한 성능을 내는 것을 목표로 한다.
🔹 Overhead
Computer science에서 overhead는 주요 작업을 수행하는 데 직접적으로 필요하지 않지만, 시스템 운영이나 프로그램 실행을 위해 추가로 소요되는 시간, 메모리, 대역폭, 기타 컴퓨팅 자원 등을 의미한다.
🔹 EL2N Score
Error L2-Norm Score를 의미하며 학습 샘플 $(x,y)$에 대해서 $\mathbb{E}||p(w_t, x) - y||_2$로 정의할 수 있다.
(시간 t에서 모델의 weight $w_t$와 입력 데이터 x의 예측값 $p(w_t, x)$와 라벨 $y$의 L2-Norm)
해당 값에 따라 학습 샘플 $x$의 난이도를 평가하는데 사용할 수 있다. EL2N Score가 높으면 모델이 해당 샘플을 잘 학습하지 못하는 것을 의미한다.
🔹 bias
Bias는 모델을 통해 얻은 예측값과 실제값의 차이의 평균을 나타낸다. 높은 Bias를 가지고 있으면 예측값이 실제값과 차이가 크다는 것을 의미한다.
🔹 Static Pruning
Static Pruning은 모델 학습 전에 중요도가 낮은 샘플을 제거하여 훈련 반복을 줄이는 방법이다. 하지만 정확도를 유지하기 위해 여러 번의 시도가 필요해 대규모 데이터셋에서는 오히려 비효율적일 수도 있다.
🔹 Dynamic Pruning
Dynamic Pruning은 Static Pruning과 다르게 학습 중에 손실 값 등을 기반으로 샘플을 동적으로 pruning 한다. static pruning 보다는 효율적이지만, 대규모 데이터셋에서 많은 overhead가 발생한다.
📏 InfoBatch
이제 본 논문에서 제안하는 방법인 InfoBatch에 대해 알아본다.
InfoBatch는 unbiased dynamic data pruning을 적용해 lossless(무손실) 성능으로 모델을 가속화하여 학습시킬 수 있는 방법이다. InfoBatch는 Classification, Semantic Segmentation, Vision 관련, instruction fine-tuning tasks에서 학습에 드는 전체 cost를 감소시키면서 손실 없이 일관된 학습 결과를 얻을 수 있다.
⚔️ 기존 방법과의 비교
아래는 이전의 방법에 대한 Figure이며 하나씩 보면 다음과 같다.
EL2N score를 측정하여 학습 샘플들을 Hard Prune시킨다. 이때 EL2N score에 대해 정렬(sorting)을 하기 때문에 $O(logN)$ 만큼의 시간이 소요되는 것을 볼 수 있다. 이후 학습을 진행했을 때 생겨버린 Bias로 인해 성능이 78.2%에서 71.0%로 감소한 것을 볼 수 있다.
이러한 문제를 해결하기 위해 논문에서는 InfoBatch를 제안한다.
우선 학습을 진행하면서 Loss Value가 작은 값들을 Soft Prune을 하여 일부만 제거하며 학습시킨다. 이때 기존 방법과 다르게 O(1)만큼의 시간이 소요되는 것을 볼 수 있다. 이후에 Soft prune을 하고 남은 부분만 Rescale 하여 원본 데이터셋과 동일한 gradient를 유지한다. 이렇게 학습을 한 결과 unbiased + Lossless의 결과를 얻을 수 있다.
🔍 세부 사항
InfoBatch에 어떻게 Lossless를 달성할 수 있었는지 더 디테일하게 살펴본다.
🔹 Threshold $\bar{\mathcal{H}}_t$ 설정
$\bar{\mathcal{H}}_t$는 샘플에 대한 Loss 값의 평균을 의미한다. 해당 값을 기준으로 $\bar{\mathcal{H}}_t$ 보다 작으면 $D1$, 크면 $D2$의 영역으로 설정된다. 기존 $O(logN)$만큼 걸리던 이유는 정렬을 하기 때문이고, $O(1)$만큼 걸리는 이유는 평균값 $\bar{\mathcal{H}}_t$와 비교만 하면 되기 때문이다.
$\bar{\mathcal{H}}_t$ 과 Compare 코드
- score의 평균과 각 score값을 비교해서 index를 체크한다.
well_learned_mask = (self.scores < self.scores.mean()).numpy()
well_learned_indices = np.where(well_learned_mask)[0]
🔹 Soft Pruning
데이터셋 $D$를 $D1$과 $D2$로 나눴으면 $D1$에 대해서만 Soft Pruning을 한다. 이때 hyper-paramter로 pruning 할 비율 $r$ 값을 정의한다. (논문에서는 pruning probabilty라고 적혀있지만 실제 코드에서는 $D1 * r$ 만큼 pruning) Pruning 할 비율만큼 무작위로 pruning을 진행한다. (이렇게 해서 pruning 된 데이터셋을 $D3$라 표현) Score가 낮은 샘플도 학습에 참여시킬 수 있어 bias가 줄어든다.
Random Pruning 코드
- prune_ratio를 기준으로 keep_ratio를 계산하여 얼마만큼 남길지 정한다.
self.keep_ratio = min(1.0, max(1e-1, 1.0 - prune_ratio))
selected_indices = np.random.choice(well_learned_indices, int(
self.keep_ratio * len(well_learned_indices)), replace=False)
🔹 Expectation Rescaling
soft pruning은 학습에 참여하는 샘플을 감소시키기 때문에 Gradient 업데이트 횟수가 감소할 수 밖에 없다. 이를 해결하기 위해 pruning 된 데이터셋 $D3$의 기울기를 원래 기울기의 $1/(1-r)$배로 scale up 한다. 이렇게 하면 원본 데이터셋과 동일한 expectation of gradient를 얻기 때문에 어느 정도 성능차이를 완화한다.
Rescaling 코드
- 아래 코드에서는 1 / self.keep_ratio로 값을 설정해준다. (나중에 이 값을 곱해주는 작업을 진행한다. `values.mul_(weights)` )
if len(selected_indices) > 0:
self.weights[selected_indices] = 1 / self.keep_ratio
remained_indices.extend(selected_indices)
🔹 Annealing
gradient expectation bias를 줄이기 위해 특정 epoch 이후부터는 전체 데이터셋으로 학습을 시킨다. 특정 epoch는 δ · C로 정의되면 δ 값은 1에 가깝다. (Code에서는 Default 값은 δ=0.875).
다시 말해 200 epoch를 학습시키는 경우에 δ=0.875인 경우 175(200x0.875) epoch까지는 위에서 설명한 방법들을 적용하여 학습시키고 그 이후부터는 원본 데이터셋 그대로 학습을 시킨다.
Annealing 코드
- 전체 epoch에서 δ를 곱해준다. 이 epoch 이후 부터는 pruning을 진행하지 않는다.
def stop_prune(self):
return self.num_epochs * self.delta
if self.iterations > self.stop_prune:
if self.iterations == self.stop_prune + 1:
self.dataset.reset_weights()
self.sample_indices = self.dataset.no_prune()
else:
self.sample_indices = self.dataset.prune()
📊 실험 결과
여러 Task에서 좋은 성능을 낼 수 있다는 것을 보여주기 위해 CIFAR-10/100, ImageNet-1K, ADE20K, FFHQ에서 효율성을 검증한다.
위 표는 Static pruning과 Dynamic pruning SOTA 방법들과 비교한 것이다. Random은 Dynamic random pruning을 적용한 것이다. ResNet-18을 사용하였으며 CIFAR10과 CIFAR100에서 InfoBatch만 모두 Lossless 학습을 할 수 있다.
다른 방법들과 비교했을 때 InfoBatch에서 Accuracy가 약간의 상승이 있었으며 Overhead가 가장 적다.
여러 Model에서도 거의 Lossless하게 높은 Prune Ratio로 성능을 유지하는 것을 볼 수 있다.
CIFAR-100 데이터셋을 사용한 것이며 Res는 Rescaling, Ann은 Annealing을 적용한 것이다. 두 개 모두 적용했을 때 Lossless 한 것을 볼 수 있다.
CIFAR-100으로 ResNet-50을 학습시킨 결과이며, 어떤 부분을 pruning 하는지에 따른 성능 차이를 보여준다. \bar{\mathcal{H}}_t 값보다 작은 부분을 pruning 하는 것이 prune을 더 많이 할 수도 있고 성능도 더 잘 유지할 수 있다.
r와 δ 의 최적의 값을 찾기 위한 실험을 하였으며 위의 Figure에서 볼 수 있듯이 r=0.5, δ=0.875일 때 일반적으로 전체 cost와 성능 부분에서 가장 효과적이었다.
이 외에도 Segmentation, Diffusion model, LLaMA 모델에서도 좋은 성능을 보인다.
🏁 결론
InfoBatch는 unbiased dynamic data pruning을 통해 Lossless 학습을 가속화할 수 있으며 여러 분야에서 쓰일 수 있다. 이전의 방법들에 비해 overhead를 최소 10배 줄일 수 있어 실제 응용에 유용하다.
⌛ 한계 및 향후 연구
Sample을 제거하면 모델 예측에 bias가 생길 수 있는데 윤리적으로 민감한 데이터셋에 적용할 때 이러한 제한 사항을 고려해야 한다.
InfoBatch는 다중 epoch에서 훈련해야 하지만 GPT-3 및 ViT-22B는 제한된 epoch에서 훈련을 하기 때문에 InfoBatch의 적용이 달라질 수 있다.
📃 참고자료
https://blog.vpromise.fun/p/data/
https://en.wikipedia.org/wiki/Overhead_(computing)
https://ganguli-gang.stanford.edu/pdf/21.DataDiet.pdf
https://cristianefragata.medium.com/machine-learning-bias-and-variance-26b6ee572af