TPU๋ก LLM ํ์ตํ๊ธฐ (feat. TRC, MaxText)
TPU๊ฐ ๋ฌด์์ธ์ง, ๊ทธ๋ฆฌ๊ณ ์ TPU๋ฅผ ์ฌ์ฉํด์ผ ํ๋์ง์ ๋ํด ์์๋ด ๋๋ค. MaxText ๋ฅผ ์ฌ์ฉํด TPU๋ก ํ์ต์ ์งํํ๋ ๋ฐฉ๋ฒ๊ณผ TPU Reserach Cloud๋ ํจ๊ป ์ดํด๋ด ๋๋ค.

chevron_right ๋ชฉ์ฐจ
TPU
๋?TPU๋ Google์์ ์ปค์คํ ๊ฐ๋ฐํ ASIC[1]. ์์ฆ LLM ์ ํ์ตํ๋ ๋ฐ ์ฐ์ด๋ GPU ๋ณด๋ค ํจ์จ์ ์ผ๋ก ๋๊ท๋ชจ ML ์ฐ์ฐ์ ์ฒ๋ฆฌํ ์ ์๋ค. ํ์ฌ v6e(ํธ๋ฆด๋ฆฌ์)๊น์ง ๊ณต๊ฐ๋์๋ค.
๋ก์ ML ์ํฌ๋ก๋๋ฅผ ๋น ๋ฅด๊ฒ ์ฒ๋ฆฌํ๋ ๋ฐ ์ฌ์ฉ๋๋ค.์ค์ฌ์ฉ ์์ ์ด๋ ค์
TPU๋ GPU๋ณด๋ค ์ด๊ธฐ ์ค์ ์ด ๋ณต์กํ๋ค. PyTorch์ ์ฌ์ฉํ๊ณ ์ ํ๋ค๋ฉด PyTorch/XLAJAX๋ฅผ ์ฌ์ฉํ๋ค๋ฉด ๊ทธ๋๋ง ๋ซ๊ฒ ์ง๋ง ์ฌ์ ํ ๊น๋ค๋ก์ด ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๋ง๋๊ฒ ๋ ๊ฒ์ด๋ค. CUDA ์ํ๊ณ ๋ฐ์ ๋์จ๋ค๋ฉด ๊ฒช์ด์ผ ํ ๋๊ด์ธ ์ ์ด๋ค. ๋ณธ์ธ์ ์๋ฅผ ๋ค๋ฉด ์ฒ์ TPU๋ฅผ ์ฌ์ฉํ ๋๋ PyTorch/XLA๋ฅผ ์ฌ์ฉํด mamba[2] ๋ชจ๋ธ์ ํ์ตํ๊ณ ์ ํ๋๋ฐ, ์์ญ ๋ถ ๋์ ์งํ์ด ๋์ง ์์๋ค. XLA๋ ๋ฐ์ดํฐ ํํ๊ฐ ๋ฐ๋๋ฉด ์ปดํ์ผ์ ๋ค์ ํด์ผ ํ๋๋ฐ, ์ด ๋๋ฌธ์ ํธํ๋์ง ์๋ ์ฝ๋๊ฐ ๊ต์ฅํ ๋ง๋ค. ์ฌ์ง์ด ๊ด๋ จ ์๋ฃ๋ ๋ถ์คํ ํธ์ด๋ค.
์ ์จ๋ฆํ๊ฒ ๋ ๊ฒ์ด๊ณ ,๊ทธ๋ผ์๋ ๋ถ๊ตฌํ๊ณ GPU ๋์ TPU๋ฅผ ์ฌ์ฉํด์ผ ํ๋ ์ด์ ๋ ๋ฌด์์ผ๊น? TPU๋ฅผ ์ฌ์ฉํด์ผ ํ ์ด์ ๊ฐ ๋ฐ๋ก ๋ค์์ ์๊ฐํ ๋ด์ฉ์ ์๋ค.
TRC
๋ฐ๋ก TPU Research Cloud, TRC๋ค. TRC์ ์ฐธ๊ฐํ๋ฉด ์๋นํ ์์ค์ ์ฑ๋ฅ์ ๊ฐ์ง TPU์ VM[*1]์ ๋ฌด๋ฃ๋ก ์ผ์ ๊ธฐ๊ฐ ๋์ ์ฌ์ฉํ ์ ์๋ค. ๋ด ๊ฒฝ์ฐ๋:
- TPU v2-8 50๊ฐ (์ ์ ํ)
- TPU v3-8 50๊ฐ (์ ์ ํ)
- TPU v4 32๊ฐ (์ ์ ํ)
- TPU v4 32๊ฐ (์จ๋๋งจ๋)
๋ฅผ ์ฒ์์ 30์ผ๋์ ์ ๊ณต๋ฐ์๋ค[*2]. ์ด๋ฅผ ํตํด devngho/ko_edu_classifier_v2_nlpai-lab_KoE5[*3] ๋ฑ์ ๋ชจ๋ธ๊ณผ devngho/the-stack-llm-annotations-v2[*4] ๋ฑ์ ๋ฐ์ดํฐ์ ์ ๊ณต๊ฐํ๋ค.
๋ฌผ๋ก TRC์๋ ์กฐ๊ฑด์ ์๋ค. TRC์ ์ง์์ ๋ฐ์ ์ฐ๊ตฌ๋ฅผ ๋๋ฃ ๊ฒํ ์ถํ๋ฌผ, ์คํ ์์ค ์ฝ๋, ๋ธ๋ก๊ทธ ๊ฒ์๋ฌผ ๋ฑ์ผ๋ก ๊ณต์ ํด์ผ ํ๋ค. ์ด๋ฆ๋ถํฐ Research Cloud์ธ ๋งํผ ์ฐ๊ตฌ ๊ฒฐ๊ณผ๋ฅผ ๊ณต์ ํ๋ ๊ฒ์ด ๋น์ฐํ๋ค. ์ด๋ฏธ ์๋ง์ ๋ ผ๋ฌธ๋ค์์ TRC๋ฅผ ํตํด TPU๋ฅผ ์ฌ์ฉํ๋ค[3]. ์ด ์ธ์๋ ๋น์ฅ ํ๊ตญ์์ ๋๋ฆฌ ์๋ ค์ง beomi/llama-2-ko-7b๋ beomi/kcbert-base ๋ฑ ๋ชจ๋ธ์ TRC ์ง์ TPU๊ฐ ์ฌ์ฉ๋์๋ค.
Shawn์ด TRC์ ๋ํด ๋จ๊ธด ๋๊ธ์ ์ฝ์ด ๋ณด๋ฉด TRC์ ์ ์ฒญํ๋ ๋ฐ ๋ถ๋ด๊ฐ ๊ฐ์ง ๋ง๊ณ ๋จผ์ ์ ์ฒญํด๋ณด๊ธธ ๊ถํ๊ณ ์๋ค. TPU ์ง์ ํ ๋ํ ์น์ ํ๊ฒ ๋์์ค ๊ฒ์ด๋ผ๊ณ ํ๊ณ ์๋ค[*5]. ๋๋ ์ด ์๊ฒฌ์ ์์ ํ ๋์ํ๋ค. ํ์ ํ๊ณ ์ถ์ง๋ง ๋น์ฉ์ด๋ ์์ ๋ฌธ์ ๋ก ๋ชป ํ๋ ์ฐ๊ตฌ๊ฐ ์๋ค๋ฉด ๋ํ ๋์ ์์ด ์ข์ ๊ธฐํ์ผ ๊ฒ์ด๊ณ , ๊ทธ๋ ์ง ์๋๋ผ๋ ์ผ๋จ ์ ์ฒญํด๋ณด์.
TPU๋ก LLM ํ์ตํ๊ธฐ (feat. MaxText)
์ด์ TPU๋ฅผ ์ฌ์ฉํด์ผ ํ ์ด์ ์, ์ด๋ป๊ฒ ์จ๋ณผ ์ ์๋์ง ์๊ฐํ๋ค. ์ด์ TPU๋ก LLM์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด์. HF[*6], PyTorch/XLA๋ก ์ฌ์ฉํ์๊ธฐ์๋ FSDPv2๋ก ํ์ตํ ์๋ ์์ง๋ง ์ฌ๋ฌ ๋ฌธ์ ๊ฐ ์๋ค[*7]. ๋์ ์ ์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ ๊ฒ์ maxtext๋ค. maxtext๋ JAX๋ก ์ฐ์ฌ์ง ๊ณ ์ฑ๋ฅ, ๊ณ ํ์ฅ์ฑ ์คํ์์ค LLM ์ฝ๋๋ฒ ์ด์ค๋ค. maxtext๋ ๊ตฌ๊ธ์์ ๋ง๋ ๋งํผ ์ฒ์๋ถํฐ TPU๋ฅผ ์ผ๋์ ๋๊ณ ๋ง๋ค์ด์ก์ผ๋ฉฐ ์ฌ์ฉํ๊ธฐ ์ฝ๋ค.
transformers๋ JAX์ ๋ํ ์ง์์ด ๋ถ์คํ๊ณTPU ๋ ธ๋ ์์ฑ
TPU๋ ์์๊ฐ ๋ง์ ๋ ธ๋ ์์ฑ์ ์ค๋ ์๊ฐ์ด ๊ฑธ๋ฆด ์ ์๋ค. ์๋์ฒ๋ผ ํ์ ๋ฃ์ด ๋๊ณ ์์ฑ์ ๊ธฐ๋ค๋ฆฌ๋ ๋ฐฉ๋ฒ์ด ์๋ค. ๊ฒฝํ์ ๊ธธ์ด๋ ์ดํ ์ ๋๋ฉด ์์ฑ๋๋ค.
์คํ ๋ ธ๋๋ ๋ณด๋ค ์งง์ ์๊ฐ ๋ด์ ์์ฑ๋๋ ํธ์ด๋, ์ธ์ ์ข ๋ฃ๋ ์ง ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์ ์ฃผ์ํด์ผ ํ๋ค. ๊ธฐ์กด์ ์ฌ์ฉ๋๋ ์ ์ ํ ๋ ธ๋๋ 24์๊ฐ ์ ํ์ด ์์ผ๋ฏ๋ก ์คํ ๋ ธ๋๋ฅผ ์ฌ์ฉํ์. TRC ํ์ ๋ฌธ์ํ์ ๋ ์ ์ ํ ๋ฌด๋ฃ ํ ๋น๋์ผ๋ก ์คํ ๋ ธ๋๋ฅผ ๋ง๋ค ์ ์๋ค๋ ๋ต๋ณ์ ๋ฐ์๋ค.
gcloud alpha compute tpus queued-resources create $QUEUED_RESOURCE_ID \
--node-id $NODE_ID \
--project $PROJECT_ID \
--zone $REGION \
--accelerator-type $NODE_TYPE \
--runtime-version tpu-ubuntu2204-base \
--spot \ // ์คํ ๋
ธ๋๋ฅผ ๋ง๋ค๊ณ ์ ํ๋ค๋ฉด ์ถ๊ฐ.
--scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/pubsub,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/trace.append,https://www.googleapis.com/auth/devstorage.full_control
TPU ๋ ธ๋ ์ ์
๋ ธ๋๊ฐ ์์ฑ๋๋ฉด SSH๋ก ์ ์ํ ์ ์๋ค. ์๋ ๋ช ๋ น์ด๋ฅผ ํตํด ์ ์ํ์.
gcloud alpha compute tpus tpu-vm ssh $node --zone=$region --project=$PROJECT_ID --worker=0
screen์ ๋์ ๋ช ๋ น์ด๋ฅผ ์คํํ๋ ค๋ฉด ์๋์ ๊ฐ์ด ํ์. ๋๋ ๋ช ๋ น์ด๋ฅผ gist์ ์ฌ๋ ค๋๊ณ ์ฌ๋ฌ ๋ ธ๋์์ ์คํํ๋ค. gist๋ ๋น์ฐํ๊ฒ private๋ก ๋ง๋ค์ด์ผ ํ๋ค.
gcloud alpha compute tpus tpu-vm ssh $node \
--zone=$region \
--project=$PROJECT_ID \
--worker=all \
--command="screen -L -Logfile logfile.txt -d -m bash -c \"bash <(curl -sL $path)\""
์ค์น
๋จผ์ repo๋ฅผ ํด๋ก ํ๊ณ ํ์ํ ํจํค์ง๋ฅผ ์ค์นํ๋ค.
git clone AI-Hypercomputer/maxtext
cd maxtext
pip install -r ./requirements.txt
๋ฐ์ดํฐ์ ์ค๋น
maxtext์์๋ ํฌ๊ฒ 3๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ๋ฐ์ดํฐ์ ์ ๋ก๋ฉํ ์ ์๋ค. HuggingFace, Grain, TFDS๋ค[4]. ์ฌ๊ธฐ์๋ HuggingFace์ ์์๋ฅผ ๋ณด์.
# huggingface hub์์ ์คํธ๋ฆฌ๋ฐํ๋ ์์๋ค.
dataset_type: hf
hf_path: 'allenai/c4' # for using https://huggingface.co/datasets/allenai/c4
hf_data_dir: 'en'
hf_train_files: ''
# set eval_interval > 0 to use the specified eval dataset, otherwise, only metrics on the train set will be calculated.
eval_interval: 10000
hf_eval_split: 'validation'
hf_eval_files: ''
# for HF pipeline, tokenizer_path can be a path in HuggingFace Hub,
# or a local path containing tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-t5/t5-large
hf_access_token: '' # provide token if using gated dataset or tokenizer
# GCS์ ์๋ ๋ฐ์ดํฐ์
์ ์ฌ์ฉํ๋ ์์๋ค.
dataset_type: hf
hf_path: 'parquet' # or json, arrow, etc.
hf_data_dir: ''
hf_train_files: 'gs://<bucket>/<folder>/*-train-*.parquet' # match the train files
# set eval_interval > 0 to use the specified eval dataset. Otherwise, only metrics on the train set will be calculated.
eval_interval: 10000
hf_eval_split: ''
hf_eval_files: 'gs://<bucket>/<folder>/*-validation-*.parquet' # match the val files
# for HF pipeline, tokenizer_path can be a path in HuggingFace Hub,
# or a local path containing tokenizer in a format supported by transformers.AutoTokenizer
tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-t5/t5-large
ํ์ต ์ค์
ํ์ต์ ๊ฐ์ฅ ๊ธฐ์ด์ ์ธ ๋ชจ๋ธ ์ํคํ ์ฒ ๋ฑ์ ์ค์ ํด๋ณด์. ์ฌ๊ธฐ์๋ ๊ธฐ๋ณธ์ ์ธ ์ค์ ๋ง ์๊ฐํ๋ค. ์์ธํ ์ค์ ์ MaxText/configs/base.yml์์ ๋ชจ๋ ํ์ธํ ์ ์๋ค.
per_device_batch_size: 16 # ๊ธฐ๊ธฐ ๋น ๋ฐฐ์น ํฌ๊ธฐ.
gradient_accumulation_steps: 2 # ๊ทธ๋๋์ธํธ ๋์ ์คํ
.
ici_fsdp_parallelism: 16 # FSDP ๋ณ๋ ฌ์ฑ. ICI๋ ํ ์ฌ๋ผ์ด์ค ๋ด์ ์๋ ์นฉ๋ค ๊ฐ์ ๋คํธ์ํฌ๋ฅผ ์๋ฏธํ๋ค.
ici_tensor_parallelism: 1 # TP ๋ณ๋ ฌ์ฑ.
remat_policy: minimal # remat ์ ์ฑ
. minimal, full ๋ฑ์ด ์๋ค. ์ ์์๋ก ์ฑ๋ฅ์ ์ข์์ง์ง๋ง ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋์ด๋๋ค.
model_name: llama3-8b # ๋ชจ๋ธ ์ํคํ
์ฒ.
steps: 10000 # ์ด ์คํ
์.
checkpoint_period: 1000 # ์ฒดํฌํฌ์ธํธ ์ ์ฅ ์ฃผ๊ธฐ.
eval_interval: 1000 # eval ์ฃผ๊ธฐ.
learning_rate: 6e-4 # ํ์ต๋ฅ .
async_checkpointing: true # ์ฒดํฌํฌ์ธํธ ๋น๋๊ธฐ ์ ์ฅ.
ํ์ต
์ด์ ํ์ต์ ์์ํด๋ณด์. ์ค์ ์ ๋ฐ๋ก yaml ํ์ผ์ ์ ์ฅํด๋ ์ข์ง๋ง, ์ฌ๊ธฐ์๋ ์ธ๋ผ์ธ์ผ๋ก ์ค์ ์ ๋ฃ์๋ค.
python MaxText/train.py MaxText/configs/base.yml \
base_output_directory=/path/to/output \
tokenizer_path=google-t5/t5-large \
per_device_batch_size=16 \
gradient_accumulation_steps=2 \
ici_fsdp_parallelism=16 \
ici_tensor_parallelism=1 \
remat_policy=minimal \
async_checkpointing=true \
model_name=llama3-8b \
steps=10000 \
checkpoint_period=1000 \
eval_interval=1000 \
learning_rate=6e-4 \
dataset_type=hf \
hf_path=allenai/c4 \
hf_data_dir=en \
hf_train_files= \
hf_eval_split=validation \
hf_eval_files= \
hf_access_token=YOUR_HF_ACCESS_TOKEN \
enable_goodput_recording=false \ // IAM์์ ์๋น์ค ๊ณ์ ์ ๋ํ ๊ถํ์ ๋ถ์ฌํด์ผ ํ๋ค. ๋๋ ๊ทธ๋ฅ false๋ก ๋์๋ค.
monitor_goodput=false \ // ์์ ๊ฐ์
ํ์ต์ด ์๋ฃ๋๋ฉด MaxText/llama_mistral_mixtral_orbax_to_hf.py
์ ๊ฐ์ ์คํฌ๋ฆฝํธ๋ฅผ ํตํด HF Transformers์์ ๋ถ๋ฌ์ฌ ์ ์๋ ๋ชจ๋ธ๋ก ๋ณํํ ์ ์๋ค. ์ดํ HuggingFace Hub์ ์
๋ก๋ํ ์๋ ์๋ค.
์๋ ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ์ค์ ์ด๋ค. ์ด ์ธ์๋ ๋ค์ํ ์ค์ ์ด ์์ผ๋ MaxText/configs/base.yml์ ์ฐธ์กฐํ์.
์ฌ๋ด
maxtext๋ ๋น์ฐํ GPU๋ก๋ ์ฌ์ฉํ ์ ์๋ค.
maxtext๋ ์ํธ ๊ต์ฐจ ์ค์ผ ์๋ sequence packing๋ ์ ์ง์ํ๋ค! ์ผ๋ฐ์ ์ธ ๋ฐ์ดํฐ์ ์์ 2๋ฐฐ ์ด์์ผ๋ก ์๋๊ฐ ํฅ์๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก ํ์ฑํ๋์ด ์๋ค.
maxtext์์๋ Flash Attention๊ณผ ์ ์ฌํ Splash Attention์ ์ฌ์ฉํ ์ ์๋ค.
๋๋ maxtext๋ฅผ forkํ devngho/maxtext๋ฅผ ์ฌ์ฉํ๊ณ ์๋ค. mistral nemo ๋ชจ๋ธ ์ง์๊ณผ wandb ์ง์ ๋ฑ์ ์ถ๊ฐํ๋ค.
GCSFuse๋ฅผ ์ฌ์ฉํด GCS
์ ๋ฐ์ดํฐ์ ๊ณผ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ ์ ์๋ค. ๋ฐ๋ก ๋์คํฌ๋ฅผ ๋ถ์ด๋ ๊ฒ๋ณด๋ค ์ ๋ ดํ๊ณ , ๋์์ง ์์ ์๋๋ฅผ ๋ณด์ฌ์ค๋ค.GCP IP ํ ๋น๋์ด ๋ถ์กฑํ๋ค๋ฉด, TPU๋ฅผ ๋ง๋ค ๋
--internal-ips
์ต์ ์ ์ค์ ์ธ๋ถ IP๋ฅผ ์ฌ์ฉํ์ง ์๊ณ , NAT๋ฅผ ๋ง๋ค์ด ์ธ๋ถ์ ํต์ ํ๊ฒ ํ ์ ์๋ค. SSH ์ ์ ์์๋--tunnel-through-iap
์ต์ ์ ์ฃผ๋ฉด ๋๋ค.
3์ค ์์ฝ
- TRC๋ฅผ ํตํด์ ๋ฌด๋ฃ๋ก TPU๋ฅผ ์ฌ์ฉ, ์ฐ๊ตฌ๋ ๊ฐ์ธ ํ๋ก์ ํธ์ ํ์ฉํด๋ณด์.
- LLM์ ํ์ตํ๋ ค๋ฉด maxtext๋ฅผ ์ฌ์ฉํ์.
- ๊ตฌ๊ธ ๐
๊ฐ์ฃผ
- [*1] TPU v4์ ๊ฒฝ์ฐ 120core 240thread, 384GB RAM
- [*2] TPU v2-8, TPU v3-8์ v3-16 ๋ฑ์ผ๋ก ๋ฐ๊ฟ ์ธ ์ ์๋ค. TPU v4๋ v4-32์ฒ๋ผ TPU Pod๋ก ๋ฌถ์ด ์ธ ์ ์๋ค.
- [*3] ํ๊ตญ์ด ๋ฐ์ดํฐ์ ๊ต์ก์ฑ ์ ์๋ฅผ ํ๊ฐํ๋ ๋ชจ๋ธ. ์์ธํ ๋ด์ฉ์ ๋ชจ๋ธ ์นด๋ ์ฐธ์กฐ.
- [*4] ์ฝ๋ ๋ฐ์ดํฐ์ ๊ต์ก์ฑ ์ ์๋ฅผ Qwen2.5 Coder 32B Instruct ๋ชจ๋ธ๋ก ํ๊ฐํ ๋ฐ์ดํฐ์ . ์์ธํ ๋ด์ฉ์ ๋ฐ์ดํฐ์ ์นด๋ ์ฐธ์กฐ.
- [*5] ๋ด ๊ฒฝ์ฐ๋ฅผ ์๊ฐํ์๋ฉด, ๋ด ์ง๋ฌธ์ด ๋ด๊ธด ์ด๋ฉ์ผ์ ๊ธธ๋ฉด ์ดํ ์์ ๋ชจ๋ ๋๋ตํด์ค ์ ๋๋ก ์ ๋ง ์ต๊ณ ์์ค์ ์ง์์ด๋ผ๊ณ ํ ์ ์๋ค.
- [*6] PyTorch๋ก๋ง ๊ตฌํ๋ ๋ชจ๋ธ๋ ๋ง๊ณ , Flash Attention ์ง์์ด ์ ๋๋ค.
- [*7] ํด๊ฒฐ ๋ฐฉ๋ฒ์ ์ฐพ๊ธฐ ์ด๋ ค์ด ๋ฌธ์ ์ ์ฒํ ๊ฐ๋ฅ์ฑ์ด ๋๋ค. ๋ด๊ฐ ์ฌ์ฉํ๋ ๋๋ ์ฒดํฌํฌ์ธํธ ์ ์ฅ๋ ๋๋ฐ๋ก ๋์ง ์์๋ ๋ฐ๋ค๊ฐ, ์ ๋ ฅ shape๊ฐ ๋ฌ๋ผ์ง๋ฉด ๊ณ์ ์ฌ์ปดํ์ผํ๊ธฐ ๋๋ฌธ์ ๋๋ฒ๊น ํ๊ธฐ ์ด๋ ค์ด ๋ฌธ์ ๊ฐ ์ผ์ด๋๊ธฐ๋ ํ๋ค.
์ฐธ๊ณ ์๋ฃ
- [1] "Cloud TPU ์๊ฐ." Google Cloud. Accessed: Aug. 12, 2024. [Online]. Available: https://cloud.google.com/tpu/docs/intro-to-tpu?hl=ko](https://cloud.google.com/tpu/docs/intro-to-tpu?hl=ko[โ]
- [2] A. Gu and T. Dao, "Mamba: Linear-time sequence modeling with selective state spaces," 2023, arXiv: 2312.00752.[โ]
- [3] "TPU Research Cloud - Publications." Google Research. Accessed: Aug. 12, 2024. [Online]. Available: https://sites.research.google/trc/publications/[โ]
- [4] "maxtext/getting_started/Data_Input_Pipeline.md at main ยท AI-Hypercomputer/maxtext." GitHub. Accessed: Sep. 12, 2024. [Online]. Available: https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md[โ]
์ฉ์ด
- labelTPU
- Tensor Processing Unit
- labelASIC
- Application-Specific Integrated Circuit
- labelML
- Machine Learning
- labelLLM
- Large Language Model
- labelGPU
- Graphics Processing Unit
- labelXLA
- Accelerated Linear Algebra
- labelJAX
- Just Another XLA
- labelTRC
- TPU Research Cloud
- labelHF
- HuggingFace
- labelGCS
- Google Cloud Storage
์ธ์ฉํ๊ธฐ
@misc{devngho202420241208trc,
author = {Yu, Dongho},
title = {TPU๋ก LLM ํ์ตํ๊ธฐ (feat. TRC, MaxText)},
howpublished = {\url{https://ngho.dev/posts/20241208trc}},
year = {2024},
month = {dec},
note = {Accessed: 2025-06-17}
}
APA ์ ๋ํธ. (2024๋ 12์ 9์ผ). TPU๋ก LLM ํ์ตํ๊ธฐ (feat. TRC, MaxText). devngho ๋ธ๋ก๊ทธ. https://ngho.dev/posts/20241208trc
Chicago ์ ๋ํธ. โTPU๋ก LLM ํ์ตํ๊ธฐ (feat. TRC, MaxText).โ devngho ๋ธ๋ก๊ทธ (blog). 2024๋ 12์ 9์ผ. https://ngho.dev/posts/20241208trc.
MLA ์ ๋ํธ. โTPU๋ก LLM ํ์ตํ๊ธฐ (feat. TRC, MaxText).โ devngho ๋ธ๋ก๊ทธ, 2024๋ 12์ 9์ผ, https://ngho.dev/posts/20241208trc.