Paper : https://arxiv.org/abs/2303.04947

 

InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning

Data pruning aims to obtain lossless performances with less overall cost. A common approach is to filter out samples that make less contribution to the training. This could lead to gradient expectation bias compared to the original data. To solve this prob

arxiv.org

GitHub : https://github.com/NUS-HPC-AI-Lab/InfoBatch/tree/master

 

GitHub - NUS-HPC-AI-Lab/InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning

Lossless Training Speed Up by Unbiased Dynamic Data Pruning - NUS-HPC-AI-Lab/InfoBatch

github.com

 

InfoBatch는 Data pruning 프레임워크 중 하나로 lossless(무손실) 성능으로 모델을 가속화하여 학습시키는 방법이다.

 

📚 사전 지식

이 논문을 효과적으로 이해하는 데 필요한 사전 지식을 정리했다. 이는 각자 가진 도메인 지식의 차이를 고려하여, 공통된 이해 기반을 마련하기 위함이다.

 

🔹 Data Pruning

우선 Pruning이란 가지치기라고 말하기도 하며 모델에서 중요도 낮은 파라미터를 제거하는 작업을 의미한다. Data Pruning은 학습에 사용되는 데이터의 크기를 줄여 일부만 학습에 참여시키는 것이다. (어떤 데이터를 줄이고 언제 줄일지 등 여러 연구가 있다.) Data Pruning에서 목표는 원본 데이터로 학습한 모델과 Data Pruning을 진행한 데이터로 학습한 모델의 성능이 비슷한 성능을 내는 것을 목표로 한다.

🔹 Overhead

Computer science에서 overhead는 주요 작업을 수행하는 데 직접적으로 필요하지 않지만, 시스템 운영이나 프로그램 실행을 위해 추가로 소요되는 시간, 메모리, 대역폭, 기타 컴퓨팅 자원 등을 의미한다.

 

🔹 EL2N Score

Error L2-Norm Score를 의미하며 학습 샘플 $(x,y)$에 대해서 $\mathbb{E}||p(w_t, x) - y||_2$로 정의할 수 있다.

(시간 t에서 모델의 weight $w_t$와 입력 데이터 x의 예측값 $p(w_t, x)$와 라벨 $y$의 L2-Norm)

해당 값에 따라 학습 샘플 $x$의 난이도를 평가하는데 사용할 수 있다. EL2N Score가 높으면 모델이 해당 샘플을 잘 학습하지 못하는 것을 의미한다.

 

🔹 bias

Bias는 모델을 통해 얻은 예측값과 실제값의 차이의 평균을 나타낸다. 높은 Bias를 가지고 있으면 예측값이 실제값과 차이가 크다는 것을 의미한다.

 

🔹 Static Pruning

Static Pruning은 모델 학습 전에 중요도가 낮은 샘플을 제거하여 훈련 반복을 줄이는 방법이다. 하지만 정확도를 유지하기 위해 여러 번의 시도가 필요해 대규모 데이터셋에서는 오히려 비효율적일 수도 있다.

 

🔹 Dynamic Pruning

Dynamic Pruning은 Static Pruning과 다르게 학습 중에 손실 값 등을 기반으로 샘플을 동적으로 pruning 한다. static pruning 보다는 효율적이지만, 대규모 데이터셋에서 많은 overhead가 발생한다.

 

📏 InfoBatch

이제 본 논문에서 제안하는 방법인 InfoBatch에 대해 알아본다.

InfoBatch는 unbiased dynamic data pruning을 적용해 lossless(무손실) 성능으로 모델을 가속화하여 학습시킬 수 있는 방법이다. InfoBatch는 Classification, Semantic Segmentation, Vision 관련, instruction fine-tuning tasks에서 학습에 드는 전체 cost를 감소시키면서 손실 없이 일관된 학습 결과를 얻을 수 있다.

 

⚔️ 기존 방법과의 비교

아래는 이전의 방법에 대한 Figure이며 하나씩 보면 다음과 같다.

 

EL2N score를 측정하여 학습 샘플들을 Hard Prune시킨다. 이때 EL2N score에 대해 정렬(sorting)을 하기 때문에 $O(logN)$ 만큼의 시간이 소요되는 것을 볼 수 있다. 이후 학습을 진행했을 때 생겨버린 Bias로 인해 성능이 78.2%에서 71.0%로 감소한 것을 볼 수 있다.

 

이러한 문제를 해결하기 위해 논문에서는 InfoBatch를 제안한다.

 

우선 학습을 진행하면서 Loss Value가 작은 값들을 Soft Prune을 하여 일부만 제거하며 학습시킨다. 이때 기존 방법과 다르게 O(1)만큼의 시간이 소요되는 것을 볼 수 있다. 이후에 Soft prune을 하고 남은 부분만 Rescale 하여 원본 데이터셋과 동일한 gradient를 유지한다. 이렇게 학습을 한 결과 unbiased + Lossless의 결과를 얻을 수 있다.

 

 

🔍 세부 사항

InfoBatch에 어떻게 Lossless를 달성할 수 있었는지 더 디테일하게 살펴본다.

 

🔹 Threshold $\bar{\mathcal{H}}_t$ 설정

$\bar{\mathcal{H}}_t$는 샘플에 대한 Loss 값의 평균을 의미한다. 해당 값을 기준으로 $\bar{\mathcal{H}}_t$ 보다 작으면 $D1$, 크면 $D2$의 영역으로 설정된다. 기존 $O(logN)$만큼 걸리던 이유는 정렬을 하기 때문이고, $O(1)$만큼 걸리는 이유는 평균값 $\bar{\mathcal{H}}_t$와 비교만 하면 되기 때문이다.

 

$\bar{\mathcal{H}}_t$ 과 Compare 코드

  • score의 평균과 각 score값을 비교해서 index를 체크한다.
well_learned_mask = (self.scores < self.scores.mean()).numpy()
well_learned_indices = np.where(well_learned_mask)[0]

 

🔹 Soft Pruning

데이터셋 $D$를 $D1$과 $D2$로 나눴으면 $D1$에 대해서만 Soft Pruning을 한다. 이때 hyper-paramter로 pruning 할 비율 $r$ 값을 정의한다. (논문에서는 pruning probabilty라고 적혀있지만 실제 코드에서는 $D1 * r$ 만큼 pruning) Pruning 할 비율만큼 무작위로 pruning을 진행한다. (이렇게 해서 pruning 된 데이터셋을 $D3$라 표현) Score가 낮은 샘플도 학습에 참여시킬 수 있어 bias가 줄어든다.

 

Random Pruning 코드

  • prune_ratio를 기준으로 keep_ratio를 계산하여 얼마만큼 남길지 정한다.
self.keep_ratio = min(1.0, max(1e-1, 1.0 - prune_ratio))

selected_indices = np.random.choice(well_learned_indices, int(
            self.keep_ratio * len(well_learned_indices)), replace=False)

 

🔹 Expectation Rescaling

soft pruning은 학습에 참여하는 샘플을 감소시키기 때문에 Gradient 업데이트 횟수가 감소할 수 밖에 없다. 이를 해결하기 위해 pruning 된 데이터셋 $D3$의 기울기를 원래 기울기의 $1/(1-r)$배로 scale up 한다. 이렇게 하면 원본 데이터셋과 동일한 expectation of gradient를 얻기 때문에 어느 정도 성능차이를 완화한다.

 

Rescaling 코드

  • 아래 코드에서는 1 / self.keep_ratio로 값을 설정해준다. (나중에 이 값을 곱해주는 작업을 진행한다. `values.mul_(weights)` )
if len(selected_indices) > 0:
            self.weights[selected_indices] = 1 / self.keep_ratio
            remained_indices.extend(selected_indices)

 

🔹 Annealing

gradient expectation bias를 줄이기 위해 특정 epoch 이후부터는 전체 데이터셋으로 학습을 시킨다. 특정 epoch는 δ · C로 정의되면 δ 값은 1에 가깝다. (Code에서는 Default 값은 δ=0.875).

다시 말해 200 epoch를 학습시키는 경우에 δ=0.875인 경우 175(200x0.875) epoch까지는 위에서 설명한 방법들을 적용하여 학습시키고 그 이후부터는 원본 데이터셋 그대로 학습을 시킨다.

 

Annealing 코드

  • 전체 epoch에서 δ를 곱해준다. 이 epoch 이후 부터는 pruning을 진행하지 않는다.
def stop_prune(self):
        return self.num_epochs * self.delta

 

if self.iterations > self.stop_prune:
            if self.iterations == self.stop_prune + 1:
                self.dataset.reset_weights()
            self.sample_indices = self.dataset.no_prune()
        else:
            self.sample_indices = self.dataset.prune()

 

📊 실험 결과

여러 Task에서 좋은 성능을 낼 수 있다는 것을 보여주기 위해 CIFAR-10/100, ImageNet-1K, ADE20K, FFHQ에서 효율성을 검증한다.

 

위 표는 Static pruning과 Dynamic pruning SOTA 방법들과 비교한 것이다. Random은 Dynamic random pruning을 적용한 것이다. ResNet-18을 사용하였으며 CIFAR10과 CIFAR100에서 InfoBatch만 모두 Lossless 학습을 할 수 있다.

 

다른 방법들과 비교했을 때 InfoBatch에서 Accuracy가 약간의 상승이 있었으며 Overhead가 가장 적다.

 

여러 Model에서도 거의 Lossless하게 높은 Prune Ratio로 성능을 유지하는 것을 볼 수 있다.

 

CIFAR-100 데이터셋을 사용한 것이며 Res는 Rescaling, Ann은 Annealing을 적용한 것이다. 두 개 모두 적용했을 때 Lossless 한 것을 볼 수 있다.

 

CIFAR-100으로 ResNet-50을 학습시킨 결과이며, 어떤 부분을 pruning 하는지에 따른 성능 차이를 보여준다. \bar{\mathcal{H}}_t 값보다 작은 부분을 pruning 하는 것이 prune을 더 많이 할 수도 있고 성능도 더 잘 유지할 수 있다.

 

 

r와 δ 의 최적의 값을 찾기 위한 실험을 하였으며 위의 Figure에서 볼 수 있듯이 r=0.5, δ=0.875일 때 일반적으로 전체 cost와 성능 부분에서 가장 효과적이었다.

 

 

이 외에도 Segmentation, Diffusion model, LLaMA 모델에서도 좋은 성능을 보인다.

 

🏁 결론

InfoBatch는 unbiased dynamic data pruning을 통해 Lossless 학습을 가속화할 수 있으며 여러 분야에서 쓰일 수 있다. 이전의 방법들에 비해 overhead를 최소 10배 줄일 수 있어 실제 응용에 유용하다.

 

⌛ 한계 및 향후 연구

Sample을 제거하면 모델 예측에 bias가 생길 수 있는데 윤리적으로 민감한 데이터셋에 적용할 때 이러한 제한 사항을 고려해야 한다.

InfoBatch는 다중 epoch에서 훈련해야 하지만 GPT-3 및 ViT-22B는 제한된 epoch에서 훈련을 하기 때문에 InfoBatch의 적용이 달라질 수 있다.

 

 

📃 참고자료

https://blog.vpromise.fun/p/data/

https://en.wikipedia.org/wiki/Overhead_(computing)

https://ganguli-gang.stanford.edu/pdf/21.DataDiet.pdf

https://cristianefragata.medium.com/machine-learning-bias-and-variance-26b6ee572af