TPU๋กœ LLM ํ•™์Šตํ•˜๊ธฐ (feat. TRC, MaxText)

ํƒœ๊ทธ TPUTRC ๋ถ„๋ฅ˜ ์ผ๋ฐ˜

TPU๊ฐ€ ๋ฌด์—‡์ธ์ง€, ๊ทธ๋ฆฌ๊ณ  ์™œ TPU๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•˜๋Š”์ง€์— ๋Œ€ํ•ด ์•Œ์•„๋ด…๋‹ˆ๋‹ค. MaxText ๋ฅผ ์‚ฌ์šฉํ•ด TPU๋กœ ํ•™์Šต์„ ์ง„ํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ TPU Reserach Cloud๋„ ํ•จ๊ป˜ ์‚ดํŽด๋ด…๋‹ˆ๋‹ค.

TPU v4์˜ ์‹ค๋ฌผ ๋ชจ์Šต. ํ˜„์žฌ๋Š” TPU v6e๊นŒ์ง€ ์ถœ์‹œ๋˜์—ˆ๋‹ค. (์ถœ์ฒ˜: ๊ตฌ๊ธ€)

TPU๋ž€?

TPU๋Š” Google์—์„œ ์ปค์Šคํ…€ ๊ฐœ๋ฐœํ•œ ASIC๋กœ์„œ ML ์›Œํฌ๋กœ๋“œ๋ฅผ ๋น ๋ฅด๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋œ๋‹ค.[1]. ์š”์ฆ˜ LLM์„ ํ•™์Šตํ•˜๋Š” ๋ฐ ์“ฐ์ด๋Š” GPU๋ณด๋‹ค ํšจ์œจ์ ์œผ๋กœ ๋Œ€๊ทœ๋ชจ ML ์—ฐ์‚ฐ์„ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋‹ค. ํ˜„์žฌ v6e(ํŠธ๋ฆด๋ฆฌ์›€)๊นŒ์ง€ ๊ณต๊ฐœ๋˜์—ˆ๋‹ค.

์‹ค์‚ฌ์šฉ ์‹œ์˜ ์–ด๋ ค์›€

TPU๋Š” GPU๋ณด๋‹ค ์ดˆ๊ธฐ ์„ค์ •์ด ๋ณต์žกํ•˜๋‹ค. PyTorch์™€ ์‚ฌ์šฉํ•˜๊ณ ์ž ํ•œ๋‹ค๋ฉด PyTorch/XLA์™€ ์”จ๋ฆ„ํ•˜๊ฒŒ ๋  ๊ฒƒ์ด๊ณ , JAX๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค๋ฉด ๊ทธ๋‚˜๋งˆ ๋‚ซ๊ฒ ์ง€๋งŒ ์—ฌ์ „ํžˆ ๊นŒ๋‹ค๋กœ์šด ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ฅผ ๋งŒ๋‚˜๊ฒŒ ๋  ๊ฒƒ์ด๋‹ค. CUDA ์ƒํƒœ๊ณ„ ๋ฐ–์„ ๋‚˜์˜จ๋‹ค๋ฉด ๊ฒช์–ด์•ผ ํ•  ๋‚œ๊ด€์ธ ์…ˆ์ด๋‹ค. ๋ณธ์ธ์˜ ์˜ˆ๋ฅผ ๋“ค๋ฉด ์ฒ˜์Œ TPU๋ฅผ ์‚ฌ์šฉํ•  ๋•Œ๋Š” PyTorch/XLA๋ฅผ ์‚ฌ์šฉํ•ด mamba[2] ๋ชจ๋ธ์„ ํ•™์Šตํ•˜๊ณ ์ž ํ–ˆ๋Š”๋ฐ, ์ˆ˜์‹ญ ๋ถ„ ๋™์•ˆ ์ง„ํ–‰์ด ๋˜์ง€ ์•Š์•˜๋‹ค. XLA๋Š” ๋ฐ์ดํ„ฐ ํ˜•ํƒœ๊ฐ€ ๋ฐ”๋€Œ๋ฉด ์ปดํŒŒ์ผ์„ ๋‹ค์‹œ ํ•ด์•ผ ํ•˜๋Š”๋ฐ, ์ด ๋•Œ๋ฌธ์— ํ˜ธํ™˜๋˜์ง€ ์•Š๋Š” ์ฝ”๋“œ๊ฐ€ ๊ต‰์žฅํžˆ ๋งŽ๋‹ค. ์‹ฌ์ง€์–ด ๊ด€๋ จ ์ž๋ฃŒ๋„ ๋ถ€์‹คํ•œ ํŽธ์ด๋‹ค.

๊ทธ๋Ÿผ์—๋„ ๋ถˆ๊ตฌํ•˜๊ณ  GPU ๋Œ€์‹  TPU๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•˜๋Š” ์ด์œ ๋Š” ๋ฌด์—‡์ผ๊นŒ? TPU๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•  ์ด์œ ๊ฐ€ ๋ฐ”๋กœ ๋‹ค์Œ์— ์†Œ๊ฐœํ•  ๋‚ด์šฉ์— ์žˆ๋‹ค.

TRC

๋ฐ”๋กœ TPU Research Cloud, TRC๋‹ค. TRC์— ์ฐธ๊ฐ€ํ•˜๋ฉด ์ƒ๋‹นํ•œ ์ˆ˜์ค€์˜ ์„ฑ๋Šฅ์„ ๊ฐ€์ง„ TPU์™€ VM[*1]์„ ๋ฌด๋ฃŒ๋กœ ์ผ์ • ๊ธฐ๊ฐ„ ๋™์•ˆ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. ๋‚ด ๊ฒฝ์šฐ๋Š”:

  • TPU v2-8 50๊ฐœ (์„ ์ ํ˜•)
  • TPU v3-8 50๊ฐœ (์„ ์ ํ˜•)
  • TPU v4 32๊ฐœ (์„ ์ ํ˜•)
  • TPU v4 32๊ฐœ (์˜จ๋””๋งจ๋“œ)

๋ฅผ ์ฒ˜์Œ์— 30์ผ๋™์•ˆ ์ œ๊ณต๋ฐ›์•˜๋‹ค[*2]. ์ด๋ฅผ ํ†ตํ•ด devngho/ko_edu_classifier_v2_nlpai-lab_KoE5[*3] ๋“ฑ์˜ ๋ชจ๋ธ๊ณผ devngho/the-stack-llm-annotations-v2[*4] ๋“ฑ์˜ ๋ฐ์ดํ„ฐ์…‹์„ ๊ณต๊ฐœํ–ˆ๋‹ค.

๋ฌผ๋ก  TRC์—๋„ ์กฐ๊ฑด์€ ์žˆ๋‹ค. TRC์˜ ์ง€์›์„ ๋ฐ›์€ ์—ฐ๊ตฌ๋ฅผ ๋™๋ฃŒ ๊ฒ€ํ†  ์ถœํŒ๋ฌผ, ์˜คํ”ˆ ์†Œ์Šค ์ฝ”๋“œ, ๋ธ”๋กœ๊ทธ ๊ฒŒ์‹œ๋ฌผ ๋“ฑ์œผ๋กœ ๊ณต์œ ํ•ด์•ผ ํ•œ๋‹ค. ์ด๋ฆ„๋ถ€ํ„ฐ Research Cloud์ธ ๋งŒํผ ์—ฐ๊ตฌ ๊ฒฐ๊ณผ๋ฅผ ๊ณต์œ ํ•˜๋Š” ๊ฒƒ์ด ๋‹น์—ฐํ•˜๋‹ค. ์ด๋ฏธ ์ˆ˜๋งŽ์€ ๋…ผ๋ฌธ๋“ค์—์„œ TRC๋ฅผ ํ†ตํ•ด TPU๋ฅผ ์‚ฌ์šฉํ–ˆ๋‹ค[3]. ์ด ์™ธ์—๋„ ๋‹น์žฅ ํ•œ๊ตญ์—์„œ ๋„๋ฆฌ ์•Œ๋ ค์ง„ beomi/llama-2-ko-7b๋‚˜ beomi/kcbert-base ๋“ฑ ๋ชจ๋ธ์— TRC ์ง€์› TPU๊ฐ€ ์‚ฌ์šฉ๋˜์—ˆ๋‹ค.

Shawn์ด TRC์— ๋Œ€ํ•ด ๋‚จ๊ธด ๋Œ“๊ธ€์„ ์ฝ์–ด ๋ณด๋ฉด TRC์— ์‹ ์ฒญํ•˜๋Š” ๋ฐ ๋ถ€๋‹ด๊ฐ ๊ฐ–์ง€ ๋ง๊ณ  ๋จผ์ € ์‹ ์ฒญํ•ด๋ณด๊ธธ ๊ถŒํ•˜๊ณ  ์žˆ๋‹ค. TPU ์ง€์› ํŒ€ ๋˜ํ•œ ์นœ์ ˆํ•˜๊ฒŒ ๋„์™€์ค„ ๊ฒƒ์ด๋ผ๊ณ  ํ•˜๊ณ  ์žˆ๋‹ค[*5]. ๋‚˜๋„ ์ด ์˜๊ฒฌ์— ์™„์ „ํžˆ ๋™์˜ํ•œ๋‹ค. ํ‰์†Œ ํ•˜๊ณ  ์‹ถ์ง€๋งŒ ๋น„์šฉ์ด๋‚˜ ์ž์› ๋ฌธ์ œ๋กœ ๋ชป ํ–ˆ๋˜ ์—ฐ๊ตฌ๊ฐ€ ์žˆ๋‹ค๋ฉด ๋”ํ•  ๋‚˜์œ„ ์—†์ด ์ข‹์€ ๊ธฐํšŒ์ผ ๊ฒƒ์ด๊ณ , ๊ทธ๋ ‡์ง€ ์•Š๋”๋ผ๋„ ์ผ๋‹จ ์‹ ์ฒญํ•ด๋ณด์ž.

TPU๋กœ LLM ํ•™์Šตํ•˜๊ธฐ (feat. MaxText)

์ด์ œ TPU๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•  ์ด์œ ์™€, ์–ด๋–ป๊ฒŒ ์จ๋ณผ ์ˆ˜ ์žˆ๋Š”์ง€ ์†Œ๊ฐœํ–ˆ๋‹ค. ์ด์ œ TPU๋กœ LLM์„ ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ์•„๋ณด์ž. HF transformers๋Š” JAX์— ๋Œ€ํ•œ ์ง€์›์ด ๋ถ€์‹คํ•˜๊ณ [*6], PyTorch/XLA๋กœ ์‚ฌ์šฉํ•˜์ž๊ธฐ์—๋Š” FSDPv2๋กœ ํ•™์Šตํ•  ์ˆ˜๋Š” ์žˆ์ง€๋งŒ ์—ฌ๋Ÿฌ ๋ฌธ์ œ๊ฐ€ ์žˆ๋‹ค[*7]. ๋Œ€์‹ ์— ์šฐ๋ฆฌ๊ฐ€ ์‚ฌ์šฉํ•  ๊ฒƒ์€ maxtext๋‹ค. maxtext๋Š” JAX๋กœ ์“ฐ์—ฌ์ง„ ๊ณ ์„ฑ๋Šฅ, ๊ณ ํ™•์žฅ์„ฑ ์˜คํ”ˆ์†Œ์Šค LLM ์ฝ”๋“œ๋ฒ ์ด์Šค๋‹ค. maxtext๋Š” ๊ตฌ๊ธ€์—์„œ ๋งŒ๋“  ๋งŒํผ ์ฒ˜์Œ๋ถ€ํ„ฐ TPU๋ฅผ ์—ผ๋‘์— ๋‘๊ณ  ๋งŒ๋“ค์–ด์กŒ์œผ๋ฉฐ ์‚ฌ์šฉํ•˜๊ธฐ ์‰ฝ๋‹ค.

TPU ๋…ธ๋“œ ์ƒ์„ฑ

TPU๋Š” ์ˆ˜์š”๊ฐ€ ๋งŽ์•„ ๋…ธ๋“œ ์ƒ์„ฑ์— ์˜ค๋žœ ์‹œ๊ฐ„์ด ๊ฑธ๋ฆด ์ˆ˜ ์žˆ๋‹ค. ์•„๋ž˜์ฒ˜๋Ÿผ ํ์— ๋„ฃ์–ด ๋†“๊ณ  ์ƒ์„ฑ์„ ๊ธฐ๋‹ค๋ฆฌ๋Š” ๋ฐฉ๋ฒ•์ด ์žˆ๋‹ค. ๊ฒฝํ—˜์ƒ ๊ธธ์–ด๋„ ์ดํ‹€ ์ •๋„๋ฉด ์ƒ์„ฑ๋œ๋‹ค.

์ŠคํŒŸ ๋…ธ๋“œ๋Š” ๋ณด๋‹ค ์งง์€ ์‹œ๊ฐ„ ๋‚ด์— ์ƒ์„ฑ๋˜๋Š” ํŽธ์ด๋‚˜, ์–ธ์ œ ์ข…๋ฃŒ๋ ์ง€ ๋ชจ๋ฅด๊ธฐ ๋•Œ๋ฌธ์— ์ฃผ์˜ํ•ด์•ผ ํ•œ๋‹ค. ๊ธฐ์กด์— ์‚ฌ์šฉ๋˜๋˜ ์„ ์ ํ˜• ๋…ธ๋“œ๋Š” 24์‹œ๊ฐ„ ์ œํ•œ์ด ์žˆ์œผ๋ฏ€๋กœ ์ŠคํŒŸ ๋…ธ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์ž. TRC ํŒ€์— ๋ฌธ์˜ํ–ˆ์„ ๋•Œ ์„ ์ ํ˜• ๋ฌด๋ฃŒ ํ• ๋‹น๋Ÿ‰์œผ๋กœ ์ŠคํŒŸ ๋…ธ๋“œ๋ฅผ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค๋Š” ๋‹ต๋ณ€์„ ๋ฐ›์•˜๋‹ค.

# Add --spot to request a spot node; omit it for on-demand.
gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
    --node-id $NODE_ID \
    --project $PROJECT_ID \
    --zone $REGION \
    --accelerator-type $NODE_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --spot \
    --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/pubsub,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/trace.append,https://www.googleapis.com/auth/devstorage.full_control

TPU ๋…ธ๋“œ ์ ‘์†

๋…ธ๋“œ๊ฐ€ ์ƒ์„ฑ๋˜๋ฉด SSH๋กœ ์ ‘์†ํ•  ์ˆ˜ ์žˆ๋‹ค. ์•„๋ž˜ ๋ช…๋ น์–ด๋ฅผ ํ†ตํ•ด ์ ‘์†ํ•˜์ž.

gcloud alpha compute tpus tpu-vm ssh $node --zone=$region --project=$PROJECT_ID --worker=0

screen์„ ๋„์›Œ ๋ช…๋ น์–ด๋ฅผ ์‹คํ–‰ํ•˜๋ ค๋ฉด ์•„๋ž˜์™€ ๊ฐ™์ด ํ•˜์ž. ๋‚˜๋Š” ๋ช…๋ น์–ด๋ฅผ gist์— ์˜ฌ๋ ค๋†“๊ณ  ์—ฌ๋Ÿฌ ๋…ธ๋“œ์—์„œ ์‹คํ–‰ํ–ˆ๋‹ค. gist๋Š” ๋‹น์—ฐํ•˜๊ฒŒ private๋กœ ๋งŒ๋“ค์–ด์•ผ ํ•œ๋‹ค.

gcloud alpha compute tpus tpu-vm ssh $node \
    --zone=$region  \
    --project=$PROJECT_ID \
    --worker=all \
    --command="screen -L -Logfile logfile.txt -d -m bash -c \"bash <(curl -sL $path)\""

์„ค์น˜

๋จผ์ € repo๋ฅผ ํด๋ก ํ•˜๊ณ  ํ•„์š”ํ•œ ํŒจํ‚ค์ง€๋ฅผ ์„ค์น˜ํ•œ๋‹ค.

git clone https://github.com/AI-Hypercomputer/maxtext.git

cd maxtext

pip install -r ./requirements.txt

๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„

maxtext์—์„œ๋Š” ํฌ๊ฒŒ 3๊ฐ€์ง€ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ฐ์ดํ„ฐ์…‹์„ ๋กœ๋”ฉํ•  ์ˆ˜ ์žˆ๋‹ค. HuggingFace, Grain, TFDS๋‹ค[4]. ์—ฌ๊ธฐ์„œ๋Š” HuggingFace์˜ ์˜ˆ์‹œ๋ฅผ ๋ณด์ž.

# huggingface hub์—์„œ ์ŠคํŠธ๋ฆฌ๋ฐํ•˜๋Š” ์˜ˆ์‹œ๋‹ค.
dataset_type: hf
hf_path: 'allenai/c4' # for using https://huggingface.co/datasets/allenai/c4
hf_data_dir: 'en'
hf_train_files: ''
# set eval_interval > 0 to use the specified eval dataset, otherwise, only metrics on the train set will be calculated.
eval_interval: 10000
hf_eval_split: 'validation'
hf_eval_files: ''
# for HF pipeline, tokenizer_path can be a path in HuggingFace Hub,
# or a local path containing tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-t5/t5-large
hf_access_token: '' # provide token if using gated dataset or tokenizer
# GCS์— ์žˆ๋Š” ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•˜๋Š” ์˜ˆ์‹œ๋‹ค.
dataset_type: hf
hf_path: 'parquet' # or json, arrow, etc.
hf_data_dir: ''
hf_train_files: 'gs://<bucket>/<folder>/*-train-*.parquet' # match the train files
# set eval_interval > 0 to use the specified eval dataset. Otherwise, only metrics on the train set will be calculated.
eval_interval: 10000
hf_eval_split: ''
hf_eval_files: 'gs://<bucket>/<folder>/*-validation-*.parquet' # match the val files
# for HF pipeline, tokenizer_path can be a path in HuggingFace Hub,
# or a local path containing tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-t5/t5-large

ํ•™์Šต ์„ค์ •

ํ•™์Šต์— ๊ฐ€์žฅ ๊ธฐ์ดˆ์ ์ธ ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ๋“ฑ์„ ์„ค์ •ํ•ด๋ณด์ž. ์—ฌ๊ธฐ์„œ๋Š” ๊ธฐ๋ณธ์ ์ธ ์„ค์ •๋งŒ ์†Œ๊ฐœํ•œ๋‹ค. ์ž์„ธํ•œ ์„ค์ •์€ MaxText/configs/base.yml์—์„œ ๋ชจ๋‘ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

per_device_batch_size: 16 # Batch size per device.
gradient_accumulation_steps: 2 # Gradient accumulation steps.
ici_fsdp_parallelism: 16 # FSDP parallelism. ICI is the interconnect between chips within a single slice.
ici_tensor_parallelism: 1 # Tensor parallelism.
remat_policy: minimal # Rematerialization policy (minimal, full, etc.). Less remat is faster but uses more memory.

model_name: llama3-8b # Model architecture.
steps: 10000 # Total number of training steps.
checkpoint_period: 1000 # How often to save checkpoints.
eval_interval: 1000 # How often to run eval.

learning_rate: 6e-4 # Learning rate.

async_checkpointing: true # Save checkpoints asynchronously.
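As a sanity check on these knobs: the global batch is the per-device batch times the device count, and gradient accumulation multiplies the effective batch per optimizer update. A rough back-of-the-envelope in Python (the device count and sequence length below are hypothetical, not taken from the config above):

```python
# Back-of-the-envelope batch math for the config above.
per_device_batch_size = 16
gradient_accumulation_steps = 2
num_devices = 16   # hypothetical chip count of one slice
seq_len = 8192     # hypothetical sequence length

global_batch = per_device_batch_size * num_devices            # sequences per step
effective_batch = global_batch * gradient_accumulation_steps  # per optimizer update
tokens_per_step = global_batch * seq_len

print(global_batch, effective_batch, tokens_per_step)  # 256 512 2097152
```

Multiply tokens_per_step by the step count to estimate how much data a run will consume before you launch it.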

ํ•™์Šต

์ด์ œ ํ•™์Šต์„ ์‹œ์ž‘ํ•ด๋ณด์ž. ์„ค์ •์„ ๋”ฐ๋กœ yaml ํŒŒ์ผ์— ์ €์žฅํ•ด๋„ ์ข‹์ง€๋งŒ, ์—ฌ๊ธฐ์„œ๋Š” ์ธ๋ผ์ธ์œผ๋กœ ์„ค์ •์„ ๋„ฃ์—ˆ๋‹ค.

python MaxText/train.py MaxText/configs/base.yml \
    base_output_directory=/path/to/output \
    tokenizer_path=google-t5/t5-large \
    per_device_batch_size=16 \
    gradient_accumulation_steps=2 \
    ici_fsdp_parallelism=16 \
    ici_tensor_parallelism=1 \
    remat_policy=minimal \
    async_checkpointing=true \
    model_name=llama3-8b \
    steps=10000 \
    checkpoint_period=1000 \
    eval_interval=1000 \
    learning_rate=6e-4 \
    dataset_type=hf \
    hf_path=allenai/c4 \
    hf_data_dir=en \
    hf_train_files= \
    hf_eval_split=validation \
    hf_eval_files= \
    hf_access_token=YOUR_HF_ACCESS_TOKEN \
    enable_goodput_recording=false \
    monitor_goodput=false

Goodput recording and monitoring require granting the service account extra permissions in IAM, so I simply left both disabled.

ํ•™์Šต์ด ์™„๋ฃŒ๋˜๋ฉด MaxText/llama_mistral_mixtral_orbax_to_hf.py์™€ ๊ฐ™์€ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ํ†ตํ•ด HF Transformers์—์„œ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ๋Š” ๋ชจ๋ธ๋กœ ๋ณ€ํ™˜ํ•  ์ˆ˜ ์žˆ๋‹ค. ์ดํ›„ HuggingFace Hub์— ์—…๋กœ๋“œํ•  ์ˆ˜๋„ ์žˆ๋‹ค.

์œ„๋Š” ๊ฐ€์žฅ ๊ธฐ๋ณธ์ ์ธ ์„ค์ •์ด๋‹ค. ์ด ์™ธ์—๋„ ๋‹ค์–‘ํ•œ ์„ค์ •์ด ์žˆ์œผ๋‹ˆ MaxText/configs/base.yml์„ ์ฐธ์กฐํ•˜์ž.

์—ฌ๋‹ด

  • maxtext๋Š” ๋‹น์—ฐํžˆ GPU๋กœ๋„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

  • maxtext๋Š” ์ƒํ˜ธ ๊ต์ฐจ ์˜ค์—ผ ์—†๋Š” sequence packing๋„ ์ž˜ ์ง€์›ํ•œ๋‹ค! ์ผ๋ฐ˜์ ์ธ ๋ฐ์ดํ„ฐ์…‹์—์„œ 2๋ฐฐ ์ด์ƒ์œผ๋กœ ์†๋„๊ฐ€ ํ–ฅ์ƒ๋œ๋‹ค. ๊ธฐ๋ณธ์ ์œผ๋กœ ํ™œ์„ฑํ™”๋˜์–ด ์žˆ๋‹ค.

  • maxtext์—์„œ๋Š” Flash Attention๊ณผ ์œ ์‚ฌํ•œ Splash Attention์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

  • ๋‚˜๋Š” maxtext๋ฅผ forkํ•œ devngho/maxtext๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ๋‹ค. mistral nemo ๋ชจ๋ธ ์ง€์›๊ณผ wandb ์ง€์› ๋“ฑ์„ ์ถ”๊ฐ€ํ–ˆ๋‹ค.

  • GCSFuse๋ฅผ ์‚ฌ์šฉํ•ด GCS์— ๋ฐ์ดํ„ฐ์…‹๊ณผ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•  ์ˆ˜ ์žˆ๋‹ค. ๋”ฐ๋กœ ๋””์Šคํฌ๋ฅผ ๋ถ™์ด๋Š” ๊ฒƒ๋ณด๋‹ค ์ €๋ ดํ•˜๊ณ , ๋‚˜์˜์ง€ ์•Š์€ ์†๋„๋ฅผ ๋ณด์—ฌ์ค€๋‹ค.

  • GCP IP ํ• ๋‹น๋Ÿ‰์ด ๋ถ€์กฑํ•˜๋‹ค๋ฉด, TPU๋ฅผ ๋งŒ๋“ค ๋•Œ --internal-ips ์˜ต์…˜์„ ์ค˜์„œ ์™ธ๋ถ€ IP๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ณ , NAT๋ฅผ ๋งŒ๋“ค์–ด ์™ธ๋ถ€์™€ ํ†ต์‹ ํ•˜๊ฒŒ ํ•  ์ˆ˜ ์žˆ๋‹ค. SSH ์ ‘์† ์‹œ์—๋Š” --tunnel-through-iap ์˜ต์…˜์„ ์ฃผ๋ฉด ๋œ๋‹ค.

3์ค„ ์š”์•ฝ

  • TRC๋ฅผ ํ†ตํ•ด์„œ ๋ฌด๋ฃŒ๋กœ TPU๋ฅผ ์‚ฌ์šฉ, ์—ฐ๊ตฌ๋‚˜ ๊ฐœ์ธ ํ”„๋กœ์ ํŠธ์— ํ™œ์šฉํ•ด๋ณด์ž.
  • LLM์„ ํ•™์Šตํ•˜๋ ค๋ฉด maxtext๋ฅผ ์‚ฌ์šฉํ•˜์ž.
  • ๊ตฌ๊ธ€ ๐Ÿ’•

๊ฐ์ฃผ

  1. [*1] TPU v4์˜ ๊ฒฝ์šฐ 120core 240thread, 384GB RAM
  2. [*2] TPU v2-8, TPU v3-8์€ v3-16 ๋“ฑ์œผ๋กœ ๋ฐ”๊ฟ” ์“ธ ์ˆ˜ ์—†๋‹ค. TPU v4๋Š” v4-32์ฒ˜๋Ÿผ TPU Pod๋กœ ๋ฌถ์–ด ์“ธ ์ˆ˜ ์žˆ๋‹ค.
  3. [*3] ํ•œ๊ตญ์–ด ๋ฐ์ดํ„ฐ์˜ ๊ต์œก์„ฑ ์ ์ˆ˜๋ฅผ ํ‰๊ฐ€ํ•˜๋Š” ๋ชจ๋ธ. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๋ชจ๋ธ ์นด๋“œ ์ฐธ์กฐ.
  4. [*4] ์ฝ”๋“œ ๋ฐ์ดํ„ฐ์˜ ๊ต์œก์„ฑ ์ ์ˆ˜๋ฅผ Qwen2.5 Coder 32B Instruct ๋ชจ๋ธ๋กœ ํ‰๊ฐ€ํ•œ ๋ฐ์ดํ„ฐ์…‹. ์ž์„ธํ•œ ๋‚ด์šฉ์€ ๋ฐ์ดํ„ฐ์…‹ ์นด๋“œ ์ฐธ์กฐ.
  5. [*5] ๋‚ด ๊ฒฝ์šฐ๋ฅผ ์†Œ๊ฐœํ•˜์ž๋ฉด, ๋‚ด ์งˆ๋ฌธ์ด ๋‹ด๊ธด ์ด๋ฉ”์ผ์„ ๊ธธ๋ฉด ์ดํ‹€ ์•ˆ์— ๋ชจ๋‘ ๋Œ€๋‹ตํ•ด์ค„ ์ •๋„๋กœ ์ •๋ง ์ตœ๊ณ  ์ˆ˜์ค€์˜ ์ง€์›์ด๋ผ๊ณ  ํ•  ์ˆ˜ ์žˆ๋‹ค.
  6. [*6] PyTorch๋กœ๋งŒ ๊ตฌํ˜„๋œ ๋ชจ๋ธ๋„ ๋งŽ๊ณ , Flash Attention ์ง€์›์ด ์•ˆ ๋œ๋‹ค.
  7. [*7] ํ•ด๊ฒฐ ๋ฐฉ๋ฒ•์„ ์ฐพ๊ธฐ ์–ด๋ ค์šด ๋ฌธ์ œ์— ์ฒ˜ํ•  ๊ฐ€๋Šฅ์„ฑ์ด ๋†’๋‹ค. ๋‚ด๊ฐ€ ์‚ฌ์šฉํ–ˆ๋˜ ๋•Œ๋Š” ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ๋„ ๋˜‘๋ฐ”๋กœ ๋˜์ง€ ์•Š์•˜๋˜ ๋ฐ๋‹ค๊ฐ€, ์ž…๋ ฅ shape๊ฐ€ ๋‹ฌ๋ผ์ง€๋ฉด ๊ณ„์† ์žฌ์ปดํŒŒ์ผํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋””๋ฒ„๊น…ํ•˜๊ธฐ ์–ด๋ ค์šด ๋ฌธ์ œ๊ฐ€ ์ผ์–ด๋‚˜๊ธฐ๋„ ํ•œ๋‹ค.

์ฐธ๊ณ  ์ž๋ฃŒ

  1. [1] "Introduction to Cloud TPU." Google Cloud. Accessed: Aug. 12, 2024. [Online]. Available: https://cloud.google.com/tpu/docs/intro-to-tpu?hl=ko
  2. [2] A. Gu and T. Dao, "Mamba: Linear-time sequence modeling with selective state spaces," 2023, arXiv: 2312.00752.
  3. [3] "TPU Research Cloud - Publications." Google Research. Accessed: Aug. 12, 2024. [Online]. Available: https://sites.research.google/trc/publications/
  4. [4] "maxtext/getting_started/Data_Input_Pipeline.md at main · AI-Hypercomputer/maxtext." GitHub. Accessed: Sep. 12, 2024. [Online]. Available: https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md

์šฉ์–ด

label
TPU
Tensor Processing Unit
label
ASIC
Application-Specific Integrated Circuit
label
ML
Machine Learning
label
LLM
Large Language Model
label
GPU
Graphics Processing Unit
label
XLA
Accelerated Linear Algebra
label
JAX
Just Another XLA
label
TRC
TPU Research Cloud
label
HF
HuggingFace
label
GCS
Google Cloud Storage

์ธ์šฉํ•˜๊ธฐ
BibTeX
@misc{devngho202420241208trc,
  author       = {Yu, Dongho},
  title        = {TPU๋กœ LLM ํ•™์Šตํ•˜๊ธฐ (feat. TRC, MaxText)},
  howpublished = {\url{https://ngho.dev/posts/20241208trc}},
  year         = {2024},
  month        = {dec},
  note         = {Accessed: 2025-06-17}
}

APA ์œ ๋™ํ˜ธ. (2024๋…„ 12์›” 9์ผ). TPU๋กœ LLM ํ•™์Šตํ•˜๊ธฐ (feat. TRC, MaxText). devngho ๋ธ”๋กœ๊ทธ. https://ngho.dev/posts/20241208trc

Chicago ์œ ๋™ํ˜ธ. โ€œTPU๋กœ LLM ํ•™์Šตํ•˜๊ธฐ (feat. TRC, MaxText).โ€ devngho ๋ธ”๋กœ๊ทธ (blog). 2024๋…„ 12์›” 9์ผ. https://ngho.dev/posts/20241208trc.

MLA ์œ ๋™ํ˜ธ. โ€œTPU๋กœ LLM ํ•™์Šตํ•˜๊ธฐ (feat. TRC, MaxText).โ€ devngho ๋ธ”๋กœ๊ทธ, 2024๋…„ 12์›” 9์ผ, https://ngho.dev/posts/20241208trc.