TPU로 LLM 학습하기 (feat. TRC, MaxText)

태그 TPUTRC

TPU가 무엇인지, 그리고 왜 TPU를 사용해야 하는지에 대해 알아봅니다. MaxText 를 사용해 TPU로 학습을 진행하는 방법과 TPU Reserach Cloud도 함께 살펴봅니다.

TPU v4의 실물 모습. 현재는 TPU v6e까지 출시되었다. (출처: 구글)

chevron_right

목차


TPU란?

TPU는 Google에서 커스텀 개발한 ASIC로서 ML 워크로드를 빠르게 처리하는 데 사용된다.[1]. 요즘 LLM을 학습하는 데 쓰이는 GPU보다 효율적으로 대규모 ML 연산을 처리할 수 있다. 현재 v6e(트릴리움)까지 공개되었다.

실사용 시의 어려움

TPUGPU보다 초기 설정이 복잡하다. PyTorch와 사용하고자 한다면 PyTorch/XLA와 씨름하게 될 것이고, JAX를 사용한다면 그나마 낫겠지만 여전히 까다로운 병렬 처리를 만나게 될 것이다. CUDA 생태계 밖을 나온다면 겪어야 할 난관인 셈이다. 본인의 예를 들면 처음 TPU를 사용할 때는 PyTorch/XLA를 사용해 mamba[2] 모델을 학습하고자 했는데, 수십 분 동안 진행이 되지 않았다. XLA는 데이터 형태가 바뀌면 컴파일을 다시 해야 하는데, 이 때문에 호환되지 않는 코드가 굉장히 많다. 심지어 관련 자료도 부실한 편이다.

그럼에도 불구하고 GPU 대신 TPU를 사용해야 하는 이유는 무엇일까? TPU를 사용해야 할 이유가 바로 다음에 소개할 내용에 있다.

TRC

바로 TPU Research Cloud, TRC다. TRC에 참가하면 상당한 수준의 성능을 가진 TPU와 VM[*1]무료로 일정 기간 동안 사용할 수 있다. 내 경우는:

  • TPU v2-8 50개 (선점형)
  • TPU v3-8 50개 (선점형)
  • TPU v4 32개 (선점형)
  • TPU v4 32개 (온디맨드)

를 처음에 30일동안 제공받았다[*2]. 이를 통해 devngho/ko_edu_classifier_v2_nlpai-lab_KoE5[*3] 등의 모델과 devngho/the-stack-llm-annotations-v2[*4] 등의 데이터셋을 공개했다.

물론 TRC에도 조건은 있다. TRC의 지원을 받은 연구를 동료 검토 출판물, 오픈 소스 코드, 블로그 게시물 등으로 공유해야 한다. 이름부터 Research Cloud인 만큼 연구 결과를 공유하는 것이 당연하다. 이미 수많은 논문들에서 TRC를 통해 TPU를 사용했다[3]. 이 외에도 당장 한국에서 널리 알려진 beomi/llama-2-ko-7bbeomi/kcbert-base 등 모델에 TRC 지원 TPU가 사용되었다.

Shawn이 TRC에 대해 남긴 댓글을 읽어 보면 TRC에 신청하는 데 부담감 갖지 말고 먼저 신청해보길 권하고 있다. TPU 지원 팀 또한 친절하게 도와줄 것이라고 하고 있다[*5]. 나도 이 의견에 완전히 동의한다. 평소 하고 싶지만 비용이나 자원 문제로 못 했던 연구가 있다면 더할 나위 없이 좋은 기회일 것이고, 그렇지 않더라도 일단 신청해보자.

TPULLM 학습하기 (feat. MaxText)

이제 TPU를 사용해야 할 이유와, 어떻게 써볼 수 있는지 소개했다. 이제 TPULLM을 학습하는 방법을 알아보자. HF transformers는 JAX에 대한 지원이 부실하고[*6], PyTorch/XLA로 사용하자기에는 FSDPv2로 학습할 수는 있지만 여러 문제가 있다[*7]. 대신에 우리가 사용할 것은 maxtext다. maxtext는 JAX로 쓰여진 고성능, 고확장성 오픈소스 LLM 코드베이스다. maxtext는 구글에서 만든 만큼 처음부터 TPU를 염두에 두고 만들어졌으며 사용하기 쉽다.

TPU 노드 생성

TPU는 수요가 많아 노드 생성에 오랜 시간이 걸릴 수 있다. 아래처럼 큐에 넣어 놓고 생성을 기다리는 방법이 있다. 경험상 길어도 이틀 정도면 생성된다.

스팟 노드는 보다 짧은 시간 내에 생성되는 편이나, 언제 종료될지 모르기 때문에 주의해야 한다. 기존에 사용되던 선점형 노드는 24시간 제한이 있으므로 스팟 노드를 사용하자. TRC 팀에 문의했을 때 선점형 무료 할당량으로 스팟 노드를 만들 수 있다는 답변을 받았다.

gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
    --node-id $NODE_ID \
    --project $PROJECT_ID \
    --zone $REGION \
    --accelerator-type $NODE_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --spot \ // 스팟 노드를 만들고자 한다면 추가.
    --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/pubsub,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/trace.append,https://www.googleapis.com/auth/devstorage.full_control

TPU 노드 접속

노드가 생성되면 SSH로 접속할 수 있다. 아래 명령어를 통해 접속하자.

gcloud alpha compute tpus tpu-vm ssh $node --zone=$region --project=$PROJECT_ID --worker=0

screen을 띄워 명령어를 실행하려면 아래와 같이 하자. 나는 명령어를 gist에 올려놓고 여러 노드에서 실행했다. gist는 당연하게 private로 만들어야 한다.

gcloud alpha compute tpus tpu-vm ssh $node \
    --zone=$region  \
    --project=$PROJECT_ID \
    --worker=all \
    --command="screen -L -Logfile logfile.txt -d -m bash -c \"bash <(curl -sL $path)\""

설치

먼저 repo를 클론하고 필요한 패키지를 설치한다.

git clone AI-Hypercomputer/maxtext

cd maxtext

pip install -r ./requirements.txt

데이터셋 준비

maxtext에서는 크게 3가지 방법으로 데이터셋을 로딩할 수 있다. HuggingFace, Grain, TFDS다[4]. 여기서는 HuggingFace의 예시를 보자.

# huggingface hub에서 스트리밍하는 예시다.
dataset_type: hf
hf_path: 'allenai/c4'  # for using https://huggingface.co/datasets/allenai/c4
hf_data_dir: 'en'
hf_train_files: ''
# set eval_interval > 0 to use the specified eval dataset, otherwise, only metrics on the train set will be calculated.
eval_interval: 10000
hf_eval_split: 'validation'
hf_eval_files: ''
# for HF pipeline, tokenizer_path can be a path in HuggingFace Hub, 
# or a local path containing tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large'  # for using https://huggingface.co/google-t5/t5-large
hf_access_token: ''  # provide token if using gated dataset or tokenizer
# GCS에 있는 데이터셋을 사용하는 예시다.
dataset_type: hf
hf_path: 'parquet'  # or json, arrow, etc.
hf_data_dir: ''
hf_train_files: 'gs://<bucket>/<folder>/*-train-*.parquet'   # match the train files
# set eval_interval > 0 to use the specified eval dataset. Otherwise, only metrics on the train set will be calculated.
eval_interval: 10000
hf_eval_split: ''
hf_eval_files: 'gs://<bucket>/<folder>/*-validation-*.parquet'  # match the val files
# for HF pipeline, tokenizer_path can be a path in HuggingFace Hub, 
# or a local path containing tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large'  # for using https://huggingface.co/google-t5/t5-large

학습 설정

학습에 가장 기초적인 모델 아키텍처 등을 설정해보자. 여기서는 기본적인 설정만 소개한다. 자세한 설정은 MaxText/configs/base.yml에서 모두 확인할 수 있다.

per_device_batch_size: 16 # 기기 당 배치 크기.
gradient_accumulation_steps: 2 # 그래디언트 누적 스텝.
ici_fsdp_parallelism: 16 # FSDP 병렬성. ICI는 한 슬라이스 내에 있는 칩들 간의 네트워크를 의미한다.
ici_tensor_parallelism: 1 # TP 병렬성.
remat_policy: minimal # remat 정책. minimal, full 등이 있다. 적을수록 성능은 좋아지지만 메모리 사용량이 늘어난다.

model_name: llama3-8b # 모델 아키텍처.
steps: 10000 # 총 스텝 수.
checkpoint_period: 1000 # 체크포인트 저장 주기.
eval_interval: 1000 # eval 주기.

learning_rate: 6e-4 # 학습률.

async_checkpointing: true # 체크포인트 비동기 저장.

학습

이제 학습을 시작해보자. 설정을 따로 yaml 파일에 저장해도 좋지만, 여기서는 인라인으로 설정을 넣었다.

python MaxText/train.py MaxText/configs/base.yml \
    base_output_directory=/path/to/output \
    tokenizer_path=google-t5/t5-large \
    per_device_batch_size=16 \
    gradient_accumulation_steps=2 \
    ici_fsdp_parallelism=16 \
    ici_tensor_parallelism=1 \
    remat_policy=minimal \
    async_checkpointing=true \
    model_name=llama3-8b \
    steps=10000 \
    checkpoint_period=1000 \
    eval_interval=1000 \
    learning_rate=6e-4 \
    dataset_type=hf \
    hf_path=allenai/c4 \
    hf_data_dir=en \
    hf_train_files= \
    hf_eval_split=validation \
    hf_eval_files= \
    hf_access_token=YOUR_HF_ACCESS_TOKEN \
    enable_goodput_recording=false \ // IAM에서 서비스 계정에 대한 권한을 부여해야 한다. 나는 그냥 false로 두었다. 
    monitor_goodput=false \ // 위와 같음

학습이 완료되면 MaxText/llama_mistral_mixtral_orbax_to_hf.py와 같은 스크립트를 통해 HF Transformers에서 불러올 수 있는 모델로 변환할 수 있다. 이후 HuggingFace Hub에 업로드할 수도 있다.

위는 가장 기본적인 설정이다. 이 외에도 다양한 설정이 있으니 MaxText/configs/base.yml을 참조하자.

여담

  • maxtext는 당연히 GPU로도 사용할 수 있다.

  • maxtext는 상호 교차 오염 없는 sequence packing도 잘 지원한다! 일반적인 데이터셋에서 2배 이상으로 속도가 향상된다. 기본적으로 활성화되어 있다.

  • maxtext에서는 Flash Attention과 유사한 Splash Attention을 사용할 수 있다.

  • 나는 maxtext를 fork한 devngho/maxtext를 사용하고 있다. mistral nemo 모델 지원과 wandb 지원 등을 추가했다.

  • GCSFuse를 사용해 GCS에 데이터셋과 모델 체크포인트를 저장할 수 있다. 따로 디스크를 붙이는 것보다 저렴하고, 나쁘지 않은 속도를 보여준다.

  • GCP IP 할당량이 부족하다면, TPU를 만들 때 --internal-ips 옵션을 줘서 외부 IP를 사용하지 않고, NAT를 만들어 외부와 통신하게 할 수 있다. SSH 접속 시에는 --tunnel-through-iap 옵션을 주면 된다.

3줄 요약

  • TRC를 통해서 무료로 TPU를 사용, 연구나 개인 프로젝트에 활용해보자.
  • LLM을 학습하려면 maxtext를 사용하자.
  • 구글 💕

각주

  1. [*1] TPU v4의 경우 120core 240thread, 384GB RAM
  2. [*2] TPU v2-8, TPU v3-8은 v3-16 등으로 바꿔 쓸 수 없다. TPU v4는 v4-32처럼 TPU Pod로 묶어 쓸 수 있다.
  3. [*3] 한국어 데이터의 교육성 점수를 평가하는 모델. 자세한 내용은 모델 카드 참조.
  4. [*4] 코드 데이터의 교육성 점수를 Qwen2.5 Coder 32B Instruct 모델로 평가한 데이터셋. 자세한 내용은 데이터셋 카드 참조.
  5. [*5] 내 경우를 소개하자면, 내 질문이 담긴 이메일을 길면 이틀 안에 모두 대답해줄 정도로 정말 최고 수준의 지원이라고 할 수 있다.
  6. [*6] PyTorch로만 구현된 모델도 많고, Flash Attention 지원이 안 된다.
  7. [*7] 해결 방법을 찾기 어려운 문제에 처할 가능성이 높다. 내가 사용했던 때는 체크포인트 저장도 똑바로 되지 않았던 데다가, 입력 shape가 달라지면 계속 재컴파일하기 때문에 디버깅하기 어려운 문제가 일어나기도 한다.

참고 자료

  1. [1]
    Google Cloud, “Cloud TPU 소개,” 2024. [Online]. Available: https://cloud.google.com/tpu/docs/intro-to-tpu?hl=ko. [Accessed 8 12 2024].
    [↑]
  2. [2]
    A. Gu and T. Dao, “Mamba: Linear-time sequence modeling with selective state spaces,” arXiv preprint arXiv:2312.00752, 2023.
    [↑]
  3. [3]
    Google Research, “TPU Research Cloud - Publications,” 2024. [Online]. Available: https://sites.research.google/trc/publications/. [Accessed 8 12 2024].
    [↑]
  4. [4]
    GitHub, “maxtext/getting_started/Data_Input_Pipeline.md at main · AI-Hypercomputer/maxtext,” 2024. [Online]. Available: https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md. [Accessed 9 12 2024].
    [↑]

용어

label
TPU
Tensor Processing Unit
label
ASIC
Application-Specific Integrated Circuit
label
ML
Machine Learning
label
LLM
Large Language Model
label
GPU
Graphics Processing Unit
label
XLA
Accelerated Linear Algebra
label
JAX
Just Another XLA
label
TRC
TPU Research Cloud
label
HF
HuggingFace
label
GCS
Google Cloud Storage

인용하기

BibTeX

@misc{devngho202420241208trc,
  author       = {Yu, Dongho},
  title        = {TPU로 LLM 학습하기 (feat. TRC, MaxText)},
  howpublished = {\url{https://ngho.dev/posts/20241208trc}},
  year         = {2024},
  month        = {dec},
  note         = {Accessed: 2024-12-25}
}

APA 유동호. (2024년 12월 9일). TPU로 LLM 학습하기 (feat. TRC, MaxText). devngho 블로그. https://ngho.dev/posts/20241208trc

Chicago 유동호. “TPU로 LLM 학습하기 (feat. TRC, MaxText).” devngho 블로그, 2024년 12월 9일, https://ngho.dev/posts/20241208trc.

MLA 유동호. “TPU로 LLM 학습하기 (feat. TRC, MaxText).” devngho 블로그, 2024년 12월 9일, https://ngho.dev/posts/20241208trc.