Google TPU上でのTunixを使用した簡単なFunctionGemmaモデルのファインチューニング
軽量なJAXベースのTunixライブラリをGoogle TPUで使用することで、FunctionGemmaモデルのファインチューニングが高速かつ簡単に行える。
キーポイント
GoogleがFunctionGemmaのTPU上でのファインチューニングをTunixライブラリで簡素化する方法を紹介
無料Colab TPU v5e-1環境でLoRAを用いた効率的なファインチューニングが可能に
JAXベースのTunixは多様なLLM後訓練技法を統一的にサポート
エッジデバイス向け軽量モデルFunctionGemmaの実用化プロセスが加速
影響分析・編集コメントを表示
影響分析
この記事は、GoogleのTPU環境とJAXエコシステムを活用した効率的な軽量LLMファインチューニングの実践的ガイドを提供する。特に無料リソースで実装可能な点は開発者コミュニティへの普及を促進し、エッジAIエージェント開発のハードルを下げる効果が期待される。
編集コメント
Google開発者向けの実践的チュートリアル記事であり、自社技術スタック(TPU/JAX/Tunix)の優位性をアピールしつつ、開発者エコシステムの囲い込みを図る典型的な技術マーケティングコンテンツと言える。
以下は、Google TPU上でTunixを用いたFunctionGemmaのファインチューニングに関する記事の詳細な要約です。
本記事は、小型言語モデル「FunctionGemma」を、GoogleのTPU上で効率的にファインチューニングするための新しい手法を紹介しています。FunctionGemmaは、自然言語を実用的なAPI呼び出しに変換するエージェントを、高速かつ低コストで開発・配備可能にするモデルであり、特にエッジデバイスでの利用に適しています。
従来のGPUを用いたHugging Face TRLライブラリを使った方法に代わり、本稿では「Tunix」というライブラリをTPU上で活用する方法を提案しています。TunixはJAXで実装された軽量ライブラリで、大規模言語モデルの学習後プロセスを効率化することを目的としています。これは拡張JAX AIスタックの一部であり、教師ありファインチューニング、パラメータ効率型ファインチューニング(PEFT)、選好チューニング、強化学習、モデル蒸留など、幅広い最新技術をサポートしています。GemmaやQwen、LLamaなどの最新オープンモデルと互換性があり、各種ハードウェアアクセラレータ上で高い効率性を発揮するように設計されています。
チュートリアルでは、具体的にLoRAを用いた教師ありファインチューニングを、無料枠のColab TPU v5e-1上で実行する手順を示しています。使用するデータセットは、前回のチュートリアルと同様の「Mobile Action」データセットです。
手順の概要は以下の通りです。まず、Hugging Face HubからFunctionGemmaのモデル重みとデータセットをダウンロードします。次に、TPUの並列処理のためにJAXのシャーディング方式を利用しますが、無料Colabの単一コアTPU環境では、シンプルなメッシュを作成します。続いて、Tunixのcreate_model_from_safe_tensors()関数を用いてsafetensors形式から直接モデル重みを読み込み、「Qwix」を使用してアテンション層にLoRAアダプターを適用します。この際、モデルの状態を適切にシャーディング制約下に置きます。
さらに、記事では「完了のみ」の損失計算をサポートするために、トレーニングデータをTunixに供給するカスタムデータセットクラスの定義についても言及しています。これにより、特定のタスクに特化したデータ形式に対応したファインチューニングが可能になります。
まとめると、本記事は、TunixとGoogle TPUを組み合わせることで、FunctionGemmaのファインチューニングを、従来のGPUベースのアプローチとは異なる、効率的でスケーラブルな方法で実行できる実践的な道筋を提供しています。特に、リソースが限られた無料のColab環境でも実行可能な点が強調されており、開発者が手軽に高性能なエージェントモデルを構築するための新たな選択肢を示す内容となっています。
原文を表示
Easy FunctionGemma finetuning with Tunix on Google TPUs
FunctionGemma is a powerful small language model that enables developers to ship fast and cost-effective agents that can translate natural language into actionable API calls, especially on edge devices. In the previous A Guide to Fine-Tuning FunctionGemma blog, our colleague shared some best practices for finetuning FunctionGemma using the Hugging Face TRL library on GPUs. In this post we are going to explore a different path by using Google Tunix to perform the finetuning on TPUs. You can find the complete notebook here.
Tunix is a lightweight library implemented in JAX and designed to streamline the post-training of Large Language Models (LLMs) and it is part of the extended JAX AI Stack. Tunix supports a wide range of modern LLM post-training techniques such as supervised finetuning, Parameter-Efficient Fine-Tuning, preference tuning, reinforcement learning, and model distillation. Tunix works with the latest open models like Gemma, Qwen and LLama, and is designed to work on a large scale of hardware accelerators with high efficiency.
In this tutorial we are going to use LoRA to do supervised finetuning on FunctionGemma and run everything on free-tier Colab TPU v5e-1. We are using the same Mobile Action dataset as in the previous finetuning tutorial.
First, we download the FunctionGemma model weights and the dataset using Hugging Face Hub.
MODEL_ID = "google/functiongemma-270m-it" DATASET_ID = "google/mobile-actions" local_model_path = snapshot_download(repo_id=MODEL_ID, ignore_patterns=["*.pth"]) data_file = hf_hub_download(repo_id=DATASET_ID, filename="dataset.jsonl", repo_type="dataset") Python Copied Tunix leverages JAX sharding schemes for parallelism under the hood. But since free-tier Colab only offers TPU v5e-1 (single core), we are creating a simple mesh without any sharding.
NUM_TPUS = len(jax.devices()) MESH = [(1, NUM_TPUS), ("fsdp", "tp")] if NUM_TPUS > 1 else [(1, 1), ("fsdp", "tp")] mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])) Python Copied Tunix can directly load the model weights from safetensors via the create_model_from_safe_tensors() function. We then use Qwix to apply the LoRA adapters to the attention layers.
with mesh: base_model = params_safetensors_lib.create_model_from_safe_tensors(local_model_path, model_config, mesh) lora_provider = qwix.LoraProvider( module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj", rank=LORA_RANK, alpha=LORA_ALPHA, ) model_input = base_model.get_model_input() model = qwix.apply_lora_to_model(base_model, lora_provider, rngs=nnx.Rngs(0), **model_input) state = nnx.state(model) pspecs = nnx.get_partition_spec(state) sharded_state = jax.lax.with_sharding_constraint(state, pspecs) nnx.update(model, sharded_state) Python Copied To support the completion-only loss, we define a custom dataset class, which we will use to feed training data into Tunix.
class CustomDataset: def __init__(self, data, tokenizer, max_length=1024): self.data = data self.tokenizer = tokenizer self.max_length = max_length def __len__(self): return len(self.data) def __iter__(self): for item in self.data: template_inputs = json.loads(item['text']) prompt_and_completion = self.tokenizer.apply_chat_template( template_inputs['messages'], tools=template_inputs['tools'], tokenize=False, add_generation_prompt=False ) prompt_only = self.tokenizer.apply_chat_template( template_inputs['messages'][:-1], tools=template_inputs['tools'], tokenize=False, add_generation_prompt=True ) tokenized_full = self.tokenizer(prompt_and_completion, add_special_tokens=False) tokenized_prompt = self.tokenizer(prompt_only, add_special_tokens=False) full_ids = tokenized_full['input_ids'] prompt_len = len(tokenized_prompt['input_ids']) if len(full_ids) > self.max_length: full_ids = full_ids[:self.max_length] input_tokens = np.full((self.max_length,), self.tokenizer.pad_token_id, dtype=np.int32) input_tokens[:len(full_ids)] = full_ids input_mask = np.zeros((self.max_length,), dtype=np.int32) if len(full_ids) > prompt_len: mask_end = min(len(full_ids), self.max_length) input_mask[prompt_len:mask_end] = 1 yield peft_trainer.TrainingInput( input_tokens=jnp.array(input_tokens, dtype=jnp.int32), input_mask=jnp.array(input_mask, dtype=jnp.int32) ) Python Copied Next we create the data generators using CustomDataset:
def data_generator(split_data, batch_size): dataset_obj = CustomDataset(split_data, tokenizer, MAX_LENGTH) batch_tokens, batch_masks = [], [] for item in dataset_obj: batch_tokens.append(item.input_tokens) batch_masks.append(item.input_mask) if len(batch_tokens) == batch_size: yield peft_trainer.TrainingInput(input_tokens=jnp.array(np.stack(batch_tokens)), input_mask=jnp.array(np.stack(batch_masks))) batch_tokens, batch_masks = [], [] print("Preparing training data...") train_batches = list(data_generator(train_data, BATCH_SI
関連記事
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み