MedQA:CUDA を不要とする AMD ROCm での臨床 AI 微調整
Hugging Face の記事は、NVIDIA CUDA に依存せず AMD ROCm を活用して医療用 AI モデル MedQA を LoRA でファインチューニングする手法を公開し、ハードウェアの多様性を推進している。
キーポイント
CUDA 依存からの脱却と AMD ROCm の実証
医療 AI の開発において NVIDIA GPU と CUDA がデファクトスタンダードとなっている現状に対し、AMD Instinct MI300X を用いた完全なトレーニングパイプラインを実装し、その有効性を示した。
高リスク領域における医療 QA モデルの構築
誤答が危険を招く可能性のある臨床 MCQ(多肢選択問題)タスクに対し、正解だけでなく推論過程の説明も出力する MedQA モデルを Qwen3-1.7B をベースに LoRA でファインチューニングした。
オープンソースコミュニティへの技術貢献
lablab.ai の AMD Developer Hackathon に出品された本プロジェクトのコード、モデル、デモが Hugging Face Hub および GitHub で公開され、他者の検証と再利用を可能にしている。
LoRAによる効率的なファインチューニング
15億パラメータのモデル全体ではなく、PEFTライブラリを用いたLoRA(Low-Rank Adaptation)により、アテンション層に小さな訓練可能な行列を注入して学習を行います。
極小の訓練可能パラメータ数
15億パラメータ中わずか約220万(0.14%)のみが訓練対象となり、これによりメモリ使用量が抑えられ、トレーニング速度が向上します。
fp16の採用とbfloat16の回避
初期実験でbfloat16を使用するとNaN損失が発生したため、標準的なfp16に切り替えることで問題を解決しました。
LoRAによる軽量なモデル保存
トレーニング完了後、数GBのフルモデルチェックポイントではなく、数MBのLoRAアダプタ重みとして保存されます。
影響分析・編集コメントを表示
影響分析
このニュースは、AI インフラストラクチャにおける NVIDIA の独占的な地位に対する重要な挑戦を示しており、特に医療のような高信頼性が求められる分野において、ハードウェアベンダー間の競争と選択肢の拡大を促す意義があります。AMD ROCm の成熟度向上が実証されたことで、コスト削減やサプライチェーンの多様化を目指す組織にとって具体的な道筋を提供するものです。
編集コメント
医療 AI の安全性と信頼性が問われる中で、ハードウェアの多様性を確保する技術的実証は極めて重要です。NVIDIA 依存からの脱却を可能にする ROCm の成熟度は、今後の業界標準形成において注目すべき指標と言えます。
*AMD Developer Hackathon にて lablab.ai で開催されたイベント向けに、AMD MI300X を使用して MedMCQA で Qwen3-1.7B の LoRA(Low-Rank Adaptation)ファインチューニングを行う完全なガイド。
アイデア
医療用質問応答は、そのリスクが極めて高いタスクの一つです。臨床的な多肢選択問題で自信を持って間違った答えを選ぶモデルは、単に間違っているだけでなく、危険です。同時に、オープンソースの医療 AI 関連の取り組みのほとんどは、NVIDIA GPU を持っていることを前提としています。CUDA がデフォルトであり、それ以外の選択肢は後回しにされがちです。
このプロジェクトはその前提に挑戦するものです。
MedQA は、AMD ハードウェア上で ROCm(Radeon Open Compute Platform)のみを使用して完全に構築された、LoRA ファインチューニング済みの臨床用質問応答モデルです。多肢選択形式の医療問題を入力すると、正解の記号 *および* 推論根拠に関する臨床的な説明を返します。データ読み込みからアダプタのエクスポートに至るまでのトレーニングパイプライン全体が、CUDA の依存関係を一切持たずに AMD Instinct MI300X 上で実行されます。
- 🤗 HuggingFace Hub 上のモデル: HK2184/medqa-qwen3-lora
- 🚀 ライブデモ: HuggingFace Spaces
- 💻 GitHub: MedQA-Medical-AI-on-AMD-ROCm
なぜ AMD ROCm か?
AMD Instinct MI300X は、1 つのデバイスに 192 GB の HBM3(High Bandwidth Memory)メモリを搭載した画期的なハードウェアです。LLM(大規模言語モデル)のファインチューニングにおいて、VRAM(ビデオメモリ)はしばしばボトルネックとなります。これはバッチサイズやシーケンス長を決定し、量子化が必要かどうかさえも左右します。192 GB の利用可能なメモリがあるため、4 ビットや 8 ビットの量子化といったハックは一切行わず、Qwen3-1.7B を LoRA でフル fp16(16 ビート浮動小数点)でトレーニングすることができました。
さらに重要なのは、HuggingFace エコシステム(Transformers, PEFT, TRL, Accelerate)が ROCm 上でシームレスに動作することを証明することでした。それは実際に可能です。CUDA で実行されるのと同じトレーニングコードを、3 つの環境変数を設定するだけで ROCm でも実行できます:
os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
これだけです。コードの変更も、カスタムカーネルの作成も、CUDA 互換性の shim(調整層)も不要です。
データセット:MedMCQA
MedMCQA は、インドの医学入学試験(AIIMS, USMLE スタイル)から派生した大規模な多肢選択式質問データセットです。各例には以下の要素が含まれます:
- 臨床に関する質問
- 4 つの回答選択肢(A~D)
- 正解のインデックス
- オプションの自由記述による解説(exp フィールド)
本プロジェクトでは、2,000 のトレーニングサンプルを使用しました。これは、意味のあるファインチューニングが短期間でも達成可能であることを示すためにあえて小さく設定したスライスです。MI300X 上でのトレーニングには約 5 分かかりました。
モデル:Qwen3-1.7B
ベースモデルは Qwen/Qwen3-1.7B です。これはアリババ社の最新の小規模言語モデルです。17 億パラメータというコンパクトなサイズでありながら、低コストでファインチューニングが可能で、かつ一貫した臨床推論を生成する能力も備えています。trust_remote_code=True をサポートしており、HuggingFace Transformers でクリーンに読み込み可能です。
プロンプト形式
プロンプト形式の一貫性は、指示微調整において極めて重要です。すべてのトレーニング例と推論呼び出しは、同じテンプレートを使用します:
質問:
{question}
オプション:
A) {opa}
B) {opb}
C) {opc}
D) {opd}
回答:
{answer_letter}) {answer_text}
解説:
{explanation}
トレーニング中は、モデルは回答と解説を含む完全なシーケンスを参照します。推論時には、「### Answer:\n」までのすべての情報を提供し、そこからモデルに完成させます。
LoRA を用いたトレーニング
15 億パラメータすべてを微調整するのではなく、PEFT ライブラリを通じてLoRA (Low-Rank Adaptation: 低ランク適応)を使用します。LoRA はアテンション層に小さな訓練可能なランク分解行列を注入し、ベースの重みは凍結したままに保ちます。
LoRA の設定
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443
モデルの 15 億パラメータのうち、約 220 万のみが訓練されます。これによりメモリ使用量が低く抑えられ、トレーニングが高速化されます。
トレーニング引数
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # 実効バッチサイズ = 16
learning_rate=2e-4,
fp16=True,
bf16=False,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
gradient_checkpointing=True,
optim="adamw_torch",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
report_to="none",
)
いくつか注目に値する点があります:
- fp16=True, bf16=False — 標準的なfp16(半精度浮動小数点数)を使用しています。bfloat16(ブロード浮動小数点16ビット)での初期実験ではNaN(数値非該当)損失が発生しましたが、fp16に切り替えることで完全に解消されました。
- gradient_checkpointing=True — 計算資源とメモリをトレードオフします。MI300X では VRAM が 192 GB あるため厳密には必須ではありませんが、より小さな GPU での再現性を高めるための良いプラクティスです。
- gradient_accumulation_steps=4 — 物理バッチサイズ 4 で実効バッチサイズ 16 を実現します。
- ウォームアップ付きの余弦 LR スケジュール(学習率スケジューリング)— 短期間のトレーニングでは平坦なスケジュールよりも収束が滑らかになります。
完全なトレーニングループ
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True,
pad_to_multiple_of=8,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collator,
)
trainer.train()
アダプターとトークナイザーを保存
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")
トレーニング完了後、./outputs ディレクトリには LoRA アダプタの重みが保存されます。これは数 MB のファイルであり、フルサイズの多 GB モデルチェックポイントとは異なります。
推論
推論時にはベースモデルを読み込み、LoRA アダプタを接続し、必要に応じて重みをマージします:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()
生成処理では、モデルがループするのを防ぐために反復ペナルティ(repetition penalty)を伴う貪欲デコーディング(greedy decoding: do_sample=False)を使用します:
def generate(prompt, model, tokenizer):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=200,
do_sample=False,
temperature=1.0,
repetition_penalty=1.1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
new_tokens = output[0][inputs["input_ids"].shape[-1]:]
return tokenizer.decode(new_tokens, skip_special_tokens=True)
出力例
質問: 以下のうち、高血圧性緊急状態に対する第一選択治療はどれですか?
A) 経口アムロジピン
B) 静注ラベタロールまたは静注ニトロプルシド
C) 舌下ニフェジピン
D) 筋注ヒドララジン
モデルの出力:
B) 静注ラベタロールまたは静注ニトロプルシド
解説:
静脈内投与されるラベタロール(β遮断薬)またはニトロプルシドは、緊急時において血圧を急速に低下させます。経口製剤は、臓器障害を防ぐために即座の血圧管理が必要な高血圧性緊急事態に対しては作用が遅すぎます。
このモデルは単に文字を出力するだけでなく、*なぜそうなのか*という理由も説明します。これこそが臨床的に有用である所以です。
HuggingFace Hub からの読み込み
ファインチューニングされたアダプターは公衆に公開されています。リポジトリをクローンすることなく、直接読み込むことができます:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-1.7B", trust_remote_code=True
)
base = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()
課題と解決策
AMD ROCm プロジェクトに、戦争の物語(苦労話)セクションがないことなどあり得ません。私たちが直面した問題はこちらです:
課題
根本原因
対策
NaN 損失 (NaN loss)
混合精度計算の不安定性
bfloat16 から fp16 に切り替え
GPU が検出されない
ROCm の環境変数が不足している
ROCR_VISIBLE_DEVICES、HIP_VISIBLE_DEVICES、HSA_OVERRIDE_GFX_VERSION を設定
bitsandbytes の非対応
bitsandbytes の ROCm ビルドが存在しない
量子化は完全に廃止 — MI300X は十分な VRAM を備えているため
推論出力のゴミデータ
トークナイザーのパディング設定が誤っていた
pad_token = eos_token とし、padding_side を修正した
Trainer の評価エラー
Transformers バージョンの不整合
transformers>=4.40.0 に固定した
bitsandbytes の問題には注記が必要である:NVIDIA ハードウェア上では、モデルをメモリに収めるために 4 ビット量子化がしばしば*必須*となる。しかし、192 GB の HBM3 を搭載した MI300X では、それは単に不要である。これは真のハードウェア上の利点であり、トレーニングがクリーンになり、量子化によるアーティファクトが発生しない。
結果
| Metric | Value |
|---|---|
| Trainable parameters | ~2.2M (全体の 0.15%) |
| Training time on MI300X | ~5 分 |
| Dataset size used | 2,000 サンプル |
| Baseline MedMCQA accuracy | ~45% |
| Framework | PyTorch + ROCm 6.1 |
実際に試す
GPU がなくても大丈夫。 ライブ Gradio デモは HuggingFace Spaces で動作している(CPU 推論):
AMD ハードウェアをお持ちですか? リポジトリをクローンしてネイティブで実行してください:
git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py # 約 5 分
python infer.py # サンプル質問を実行
python app.py # Gradio UI を起動次のステップ
このプロジェクトは、パイプラインが機能することを証明しました。次のステップは、スケーリングと堅牢化に関するものです:
- より大規模なデータセット — 完全な MedMCQA コーパス(約 180,000 問)および PubMedQA を用いてトレーニングを行う
- 信頼性スコアリング — 回答とともに補正された信頼性推定値を追加する
- RAG の統合 — リアルタイムの医療文献検索に基づき、回答を裏付ける
- 評価ハーンネス — トレーニング分割を超えた、適切なホールドアウト精度ベンチマークを行う
結論
MedQA は、オープンソースの AMD ハードウェア上で能力があり説明可能な医療 AI を構築することが可能であるだけでなく、そのプロセスは単純であることを示しています。HuggingFace エコシステムの ROCm 互換性は実際に非常に良好です。MI300X のメモリ余裕により、エンジニアリング上の問題のカテゴリー全体が排除されました。また、LoRA(Low-Rank Adaptation)を用いれば、1.7B パラメータのモデルをファインチューニングする作業はわずか 5 分で完了します。
AMD ROCm で構築を進めていて壁にぶつかった場合、上記の対策で数時間を節約できるはずです。また、医療 AI を構築している場合は、単なる精度よりも説明可能性を重視する姿勢を真剣に受け止める価値があります。
*lablab.ai(https://lablab.ai/)で開催された AMD Developer Hackathon 向けに作成 · AMD ROCm と HuggingFace エコシステムによって駆動*
— Harikrishna Sivanand Iyer および Srijan Sivaram A
原文を表示
*A complete walkthrough of LoRA fine-tuning Qwen3-1.7B on MedMCQA using AMD MI300X, built for the AMD Developer Hackathon on lablab.ai.*
The Idea
Medical question answering is one of those tasks where the stakes are genuinely high. A model that confidently picks the wrong answer on a clinical MCQ isn't just wrong — it's dangerous. At the same time, most open-source medical AI work assumes you have an NVIDIA GPU. CUDA is the default. Everything else is an afterthought.
This project challenges that assumption.
MedQA is a LoRA fine-tuned clinical question-answering model built entirely on AMD hardware using ROCm. It takes a multiple-choice medical question and returns both the correct answer letter *and* a clinical explanation of the reasoning. The entire training pipeline — from data loading to adapter export — runs on an AMD Instinct MI300X without a single CUDA dependency.
- 🤗 Model on HuggingFace Hub: HK2184/medqa-qwen3-lora
- 🚀 Live Demo: HuggingFace Spaces
- 💻 GitHub: MedQA-Medical-AI-on-AMD-ROCm
Why AMD ROCm?
The AMD Instinct MI300X is a remarkable piece of hardware: 192 GB of HBM3 memory in a single device. For LLM fine-tuning, VRAM is often the binding constraint — it dictates batch size, sequence length, and whether you need to quantize at all. With 192 GB available, we trained Qwen3-1.7B with LoRA in full fp16 without any 4-bit or 8-bit quantization hacks.
More importantly, the goal was to prove that the HuggingFace ecosystem — Transformers, PEFT, TRL, Accelerate — works seamlessly on ROCm. It does. The same training code that runs on CUDA runs on ROCm with three environment variables set:
os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
That's it. No code changes. No custom kernels. No CUDA compatibility shims.
The Dataset: MedMCQA
MedMCQA is a large-scale multiple-choice question dataset derived from Indian medical entrance exams (AIIMS, USMLE-style). Each example contains:
- A clinical question
- Four answer options (A–D)
- The correct answer index
- An optional free-text explanation (exp field)
For this project we used 2,000 training samples — a deliberately small slice to demonstrate that meaningful fine-tuning is achievable quickly. Training took approximately 5 minutes on the MI300X.
Model: Qwen3-1.7B
The base model is Qwen/Qwen3-1.7B — Alibaba's latest small-scale language model. At 1.7 billion parameters it's compact enough to fine-tune cheaply but capable enough to produce coherent clinical reasoning. It supports trust_remote_code=True and loads cleanly with HuggingFace Transformers.
The Prompt Format
Consistency in prompt formatting is critical for instruction fine-tuning. Every training example and every inference call uses the same template:
### Question:
{question}
### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}
### Answer:
{answer_letter}) {answer_text}
### Explanation:
{explanation}
During training the model sees the full sequence including the answer and explanation. During inference we provide everything up to ### Answer:\n and let the model complete from there.
Training with LoRA
Rather than fine-tuning all 1.5 billion parameters, we use LoRA (Low-Rank Adaptation) via the PEFT library. LoRA injects small trainable rank-decomposition matrices into the attention layers, leaving the base weights frozen.
LoRA Configuration
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443
Only ~2.2 million of the model's 1.5 billion parameters are trained. This keeps memory usage low and training fast.
Training Arguments
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # effective batch size = 16
learning_rate=2e-4,
fp16=True,
bf16=False,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
gradient_checkpointing=True,
optim="adamw_torch",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
report_to="none",
)
A few things worth noting:
- fp16=True, bf16=False — We use standard fp16. In early experiments with bfloat16 we encountered NaN loss; switching to fp16 resolved it entirely.
- gradient_checkpointing=True — Trades compute for memory. Not strictly necessary on MI300X given the 192 GB VRAM, but good practice for reproducibility on smaller GPUs.
- gradient_accumulation_steps=4 — Effective batch size of 16 with a physical batch of 4.
- Cosine LR schedule with warmup — Smoother convergence than a flat schedule for short training runs.
The Full Training Loop
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True,
pad_to_multiple_of=8,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collator,
)
trainer.train()
# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")
After training, ./outputs contains the LoRA adapter weights — a few MB of files rather than a full multi-GB model checkpoint.
Inference
At inference time we load the base model, attach the LoRA adapter, and optionally merge the weights:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()
Generation uses greedy decoding (do_sample=False) with a repetition penalty to prevent the model from looping:
def generate(prompt, model, tokenizer):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=200,
do_sample=False,
temperature=1.0,
repetition_penalty=1.1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
new_tokens = output[0][inputs["input_ids"].shape[-1]:]
return tokenizer.decode(new_tokens, skip_special_tokens=True)
Sample Output
Question: Which of the following is the first-line treatment for hypertensive emergency?
A) Oral amlodipine
B) IV labetalol or IV nitroprusside
C) Sublingual nifedipine
D) IM hydralazine
Model Output:
B) IV labetalol or IV nitroprusside
Explanation:
Intravenous labetalol (beta-blocker) or nitroprusside rapidly reduces blood
pressure in emergency settings. Oral agents act too slowly for hypertensive
emergencies requiring immediate BP control to prevent end-organ damage.
The model doesn't just output a letter — it explains *why*, which is what makes it clinically useful.
Loading from HuggingFace Hub
The fine-tuned adapter is publicly available. You can load it directly without cloning the repo:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-1.7B", trust_remote_code=True
)
base = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()
Challenges and Fixes
No AMD ROCm project is complete without a war story section. Here's what we ran into:
Challenge
Root Cause
Fix
NaN loss
Mixed precision instability
Switched from bfloat16 → fp16
GPU not detected
Missing ROCm env variables
Set ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, HSA_OVERRIDE_GFX_VERSION
bitsandbytes unsupported
No ROCm build of bitsandbytes
Dropped quantization entirely — MI300X has enough VRAM
Garbage inference output
Tokenizer padding misconfigured
Set pad_token = eos_token and fixed padding_side
Trainer eval errors
Transformers version mismatch
Pinned transformers>=4.40.0
The bitsandbytes issue deserves a note: on NVIDIA hardware, 4-bit quantization is often *required* to fit a model in memory. On MI300X with 192 GB HBM3, it's simply unnecessary. This is a genuine hardware advantage — cleaner training, no quantization artifacts.
Results
Metric
Value
Trainable parameters
~2.2M (0.15% of total)
Training time on MI300X
~5 minutes
Dataset size used
2,000 samples
Baseline MedMCQA accuracy
~45%
Framework
PyTorch + ROCm 6.1
Try It Yourself
No GPU? No problem. The live Gradio demo runs on HuggingFace Spaces (CPU inference):
👉 Live Demo on HuggingFace Spaces
Have AMD hardware? Clone the repo and run it natively:
git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py # ~5 minutes
python infer.py # run sample questions
python app.py # launch Gradio UI
What's Next
This project proves the pipeline works. The next steps are about scaling and hardening it:
- Larger dataset — Train on the full MedMCQA corpus (~180k questions) and add PubMedQA
- Confidence scoring — Add calibrated confidence estimates alongside answers
- RAG integration — Ground answers in real-time medical literature retrieval
- Evaluation harness — Proper held-out accuracy benchmarking beyond the training split
Conclusion
MedQA shows that building a capable, explainable medical AI on open-source AMD hardware is not only possible — it's straightforward. The HuggingFace ecosystem's ROCm compatibility is genuinely good. The MI300X's memory headroom removes an entire category of engineering problems. And LoRA makes fine-tuning a 1.7B model a 5-minute job.
If you're building on AMD ROCm and hitting walls, the fixes above should save you hours. And if you're building medical AI, the emphasis on explanation over bare accuracy is worth taking seriously.
*Built for the AMD Developer Hackathon on lablab.ai · Powered by AMD ROCm + HuggingFace ecosystem*
*— Harikrishna Sivanand Iyer and Srijan Sivaram A
関連記事
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み