FlashMemory DeepSeek-V4 リトリーバー(GitHub リポジトリ)
DeepSeek-V4 の圧縮スパースアテンション KV キャッシュを予測する軽量リトリーバー「FlashMemory」が公開され、オンデバイスメモリ使用量を約 10-15% に削減しながら性能を維持する技術的進展です。
キーポイント
KV キャッシュの動的選別メカニズム
デコードトークンの隠れ状態に基づき、次回の約 64 トークンがどの KV キャッシュチャンクにアテンションするかを予測し、重要度の低いチャンクを CPU/ディスクへオフロードします。
極限までのメモリ削減と性能維持
オンデバイス上の KV キャッシュ量を全体の 10-15% に抑えながら、フルアテンションベースラインと同等かそれ以上の推論精度を達成します。
DeepSeek-V4 専用アーキテクチャ
DeepSeek-V4 の Compressed-Sparse-Attention (CSA) 形式に特化しており、圧縮されたキー(uint8)とトークン位置情報を入力として利用します。
オープンソース実装の提供
GitHub リポジトリおよび Hugging Face でモデル重みが公開され、Python による簡易なデモや推論ループの実行が可能です。
圧縮キーのデータ構造
各チャンクは132バイト(128バイトのfloat8_e4m3量子化値+4バイトのfloat32再量子化スケール)で構成され、復元時にスケールを乗算して使用します。
CSA層におけるスコア計算フロー
クエリはLoRA投影、RMSNorm、RoPE(YaRN)、ハダマール積を経て生成され、キーは圧縮データから復元された後、両者の内積でスコアが計算されます。
Joint Checkpoint + Ensemble Strategy
The checkpoint contains three independent CSA layers (l10, l12, l20) whose per-layer sigmoid scores are ensembled via max or mean to make a single keep/drop decision for each chunk.
影響分析・編集コメントを表示
影響分析
この技術は、LLM の推論における最大のボトルネックの一つである KV キャッシュのメモリ使用量を劇的に削減する画期的なアプローチです。特に DeepSeek-V4 のような大規模モデルを、限られた GPU メモリ環境やエッジデバイスで実用的に運用するための重要なインフラ技術となり得ます。業界全体として、推論コストの低下とアクセシビリティ向上に寄与する可能性が高いです。
編集コメント
DeepSeek の最新アーキテクチャに特化したメモリ最適化技術がオープンソースで公開されたことは、推論効率化の分野において非常に注目すべき進展です。特に、性能を犠牲せずにメモリ使用量を 1/10 以下に抑えるという数値は、実運用における導入ハードルを大きく下げるものです。
軽量な検索器であり、DeepSeek-V4 の圧縮スパースアテンション (CSA) KV キャッシュをスパース化します。
デコードトークンの隠れ状態が与えられた場合、この検索器は次の約 64 トークンがどの CSA KV キャッシュチャンクに注意を向けるかを予測します。スコアが高い上位のチャンクのみが GPU に残存し、それ以外は CPU またはディスクへオフロードできます。下流の評価では、デバイス上に KV キャッシュの約 10–15% を保持したまま、フルアテンションベースラインと同等かそれ以上の性能を発揮します。
クイックスタート
pip install torch safetensors
モック入力によるデモ
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
玩具的なスパースデコードループ
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
使用方法
from retriever import FlashMemoryRetriever
model = FlashMemoryRetriever.from_checkpoint(
"weights/flashmemory_ds_v4.safetensors", device="cuda"
)
hidden: [B, 4096] デコードトークンの隠れ状態
comp_k: [B, N, 132] uint8 形式の圧縮 CSA キー
positions: [B] int64 トークン位置
レイヤーごとのシグモイドスコア:{"l10": [B,N], "l12": [B,N], "l20": [B,N]}
per_layer = model(hidden, comp_k, positions)
層間アンサンブル (mode="max" または "mean")
scores = model.ensemble(hidden, comp_k, positions, mode="max") # [B, N]
Boolean keep mask
keep = model.select_topk(hidden, comp_k, positions, top_k=512) # top-K
keep = model.select_topk(hidden, comp_k, positions, threshold=0.5) # threshold
compressed_k format
Each chunk = HEAD_DIM + 4 = 132 uint8 bytes:
Bytes
Type
Meaning
[:128]
float8_e4m3
Quantized key values (量子化されたキー値)
[128:132]
float32
Per-chunk dequant scale (チャンクごとの非量子化スケール)
Dequant: fp8_values.view(float8_e4m3).float() * scale.**
See make_mock_compressed_k() in demo.py.
Architecture
Per CSA layer, scores are computed as:
hidden [B, 4096]
→ wq_a (4096 → Q_LORA_RANK)
→ RMSNorm (q_norm_weight, eps=1e-6)
→ wq_b (Q_LORA_RANK → N_HEADS * HEAD_DIM)
→ reshape [B, N_HEADS, HEAD_DIM]
→ RoPE (YaRN, last ROPE_DIM=64 dims, base=160000)
→ Hadamard (normalized Walsh-Hadamard)
→ q [B, N_HEADS, HEAD_DIM]
hidden [B, 4096]
→ weights_proj (4096 → N_HEADS)
→ × weight_scale (= HEAD_DIM^-0.5 * N_HEADS^-0.5)
→ fused_w [B, N_HEADS]
compressed_k [B, N, HEAD_DIM + 4] (uint8)
→ bytes[:HEAD_DIM] viewed as float8_e4m3 → dequant
→ × bytes[HEAD_DIM:] viewed as float32 → k [B, N, HEAD_DIM]
score = sigmoid( sum_heads( relu(k @ q^T) * fused_w ) ) in [0, 1]
Joint checkpoint + ensemble
The checkpoint holds three independent CSA layers (l10, l12, l20),
each with its own weights. At inference time per-layer sigmoid scores are
ensembled per chunk — max (union, default) or mean — to produce a
single keep/drop decision.
ハイパーパラメータ
パラメータ
値
N_HEADS
128
HEAD_DIM
128
Q_LORA_RANK
2048
ROPE_DIM
64 (最後の 64 次元)
ROPE_BASE
160000 (YaRN)
ROPE_FACTOR
16
ROPE_ORIGINAL_SEQ_LEN
65536
ROPE_BETA_FAST
32
ROPE_BETA_SLOW
1
RMS_NORM_EPS
1e-6
トイ推論リファレンス (toy_flashmemory_inference.py)
これは、デコード中にリトリーバーがどのようにメモリ想起を駆動するかを示す自己完結型の例であり、DeepSeek-V4-FlashMemory 内部で実際に使用される制御フローを表しています。
推論フロー
┌──────────┐ 圧縮・保存 ┌────────────────────────────┐
│ PREFILL │ 履歴 K/V │ CSA KV キャッシュ(メモリ)│
│ (dense │ ───────────────────► │ N 個の圧縮チャンク、 │
│ attn) │ │ 各 = [132] uint8 fp8-K │
└────┬─────┘ └──────────────┬─────────────┘
│ 最後の隠れ状態 │ 64 ステップごとにスコアリング
▼ │
┌──────────────────────── DECODE LOOP ──────────┼──────────────────────────┐
│ 各デコードステップ t について: │ │
│ hidden = toy_decoder.step(token, keep_mask) │ (スパースメモリアテンション)│
│ │ │
│ RETRIEVAL_INTERVAL (= 64) ステップごと: ▼ │
│ scores[N] = retriever.ensemble(hidden, compressed_k, pos) │
│ keep_mask[N] = scores の上位 K (またはシグモイド > しきい値) │
│ -> 次回の 64 ステップで未選択チャンクを -inf にマスク │
└──────────────────────────────────────────────────────────────────────────┘
- Prefill(密)。短いプロンプトは密なメモリアテンションを通ります。その最後の隠れ状態が最初の検索サイクルの種となります。
- Decode loop(デコードループ)。Toy decoder は各ステップで [B, 4096] の隠れ状態を生成します。
- Retrieval cycle (every 64 steps). The real FlashMemoryRetriever scores all N compressed-K chunks, ensembles per-layer scores, selects keep chunks.
- Sparse attention. Unselected chunks' attention logits are set to -inf.
What this simulates
- This toy does NOT perform real CPU↔GPU KV-cache transfer. The swap engine is internal FlashMemory infrastructure and is not included.
- We simulate memory recall by masking attention logits to -inf. A masked chunk contributes nothing to attention — the same effect as not loading its KV.
- The purpose is to make the decode-time control flow concrete.
What it is / is NOT
IS
IS NOT
Minimal torch-only illustration of memory recall
A runnable DeepSeek-V4
Uses the real retriever weights & scoring math
Production KV swap engine
Pedagogical: shows the control flow
Meaningful text generation
The production version depends on the internal sglang + DeepSeek-V4 CSA framework (native FP8 indexer, real compressed KV-cache, attention-sink, threshold fallback, per-request routing, actual KV swap) and cannot be released.
Downstream evaluation
FlashMemory DS-V4 beats or ties the full-attention baseline on reasoning-heavy long-context tasks while keeping only ~10–15% of CSA KV cache on-device:
Task
Context
vs. Full-Attn
KV Saved
RULER (64k–512k)
64K–512K
−1 ~ +2 pp
~80–90%
LongMemEval-s
125K
±1 pp
~86%
LongMemEval-m
500K
±1 pp
~91%
LongBench V2
46K–493K
+1 ~ +2 pp
~73–90%
MRCR (needle)
274K
needs fallback
~86%
精密なニードル検索タスク(MRCR)には、推論層における追加のしきい値フォールバック機能が必要です。これはスタンドアローン版には含まれていません。
Files
File
Purpose
retriever.py
FlashMemoryRetriever モデル + RoPE/Hadamard + FP8 非量子化
demo.py
モック入力を用いた最小限のデモ
toy_flashmemory_inference.py
玩具的なスパース・デコードループ
weights/flashmemory_ds_v4.safetensors
学習済み重み(約 510 MB、Hugging Face で利用可能)
requirements.txt
依存関係
License
MIT
Citation
FlashMemory を研究で使用する場合は、以下を引用してください:
@article{wang2026flashmemory,
title = {FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention},
author = {Yan Wang and Qifan Zhang and Jiachen Yu and Tian Liang and Dongyang Ma and
Xiang Hu and Zibo Lin and Chunyang Li and Zhichao Wang and Jia Li and
Yujiu Yang and Haitao Mi and Dong Yu},
year = {2026},
journal = {arXiv preprint arXiv:2606.09079},
url = {https://huggingface.co/papers/2606.09079},
}
原文を表示
A lightweight retriever that sparsifies DeepSeek-V4 Compressed-Sparse-Attention (CSA) KV-cache.
Given the hidden state of a decode token, the retriever predicts which CSA
KV-cache chunks the next ~64 tokens will attend to. Only the top-scoring chunks
stay resident on the GPU; the rest can be offloaded to CPU/disk. In downstream
evaluation it matches or beats the full-attention baseline while keeping **~10–15%
of the KV cache** on-device.
Quick start
pip install torch safetensors
# Demo with mock inputs
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
# Toy sparse-decode loop
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensorsUsage
from retriever import FlashMemoryRetriever
model = FlashMemoryRetriever.from_checkpoint(
"weights/flashmemory_ds_v4.safetensors", device="cuda"
)
# hidden: [B, 4096] decode-token hidden state
# comp_k: [B, N, 132] uint8 compressed CSA keys
# positions: [B] int64 token positions
# Per-layer sigmoid scores: {"l10": [B,N], "l12": [B,N], "l20": [B,N]}
per_layer = model(hidden, comp_k, positions)
# Cross-layer ensemble (mode="max" or "mean")
scores = model.ensemble(hidden, comp_k, positions, mode="max") # [B, N]
# Boolean keep mask
keep = model.select_topk(hidden, comp_k, positions, top_k=512) # top-K
keep = model.select_topk(hidden, comp_k, positions, threshold=0.5) # thresholdcompressed_k format
Each chunk = HEAD_DIM + 4 = 132 uint8 bytes:
Bytes
Type
Meaning
[:128]
float8_e4m3
Quantized key values
[128:132]
float32
Per-chunk dequant scale
Dequant: fp8_values.view(float8_e4m3).float() * scale.**
See make_mock_compressed_k() in demo.py.
Architecture
Per CSA layer, scores are computed as:
hidden [B, 4096]
→ wq_a (4096 → Q_LORA_RANK)
→ RMSNorm (q_norm_weight, eps=1e-6)
→ wq_b (Q_LORA_RANK → N_HEADS * HEAD_DIM)
→ reshape [B, N_HEADS, HEAD_DIM]
→ RoPE (YaRN, last ROPE_DIM=64 dims, base=160000)
→ Hadamard (normalized Walsh-Hadamard)
→ q [B, N_HEADS, HEAD_DIM]
hidden [B, 4096]
→ weights_proj (4096 → N_HEADS)
→ × weight_scale (= HEAD_DIM^-0.5 * N_HEADS^-0.5)
→ fused_w [B, N_HEADS]
compressed_k [B, N, HEAD_DIM + 4] (uint8)
→ bytes[:HEAD_DIM] viewed as float8_e4m3 → dequant
→ × bytes[HEAD_DIM:] viewed as float32 → k [B, N, HEAD_DIM]
score = sigmoid( sum_heads( relu(k @ q^T) * fused_w ) ) in [0, 1]
Joint checkpoint + ensemble
The checkpoint holds three independent CSA layers** (l10, l12, l20),
each with its own weights. At inference time per-layer sigmoid scores are
ensembled per chunk — max (union, default) or mean — to produce a
single keep/drop decision.
Hyperparameters
Param
Value
N_HEADS
128
HEAD_DIM
128
Q_LORA_RANK
2048
ROPE_DIM
64 (last 64 dims)
ROPE_BASE
160000 (YaRN)
ROPE_FACTOR
16
ROPE_ORIGINAL_SEQ_LEN
65536
ROPE_BETA_FAST
32
ROPE_BETA_SLOW
1
RMS_NORM_EPS
1e-6
Toy inference reference (toy_flashmemory_inference.py)
A self-contained illustration of how the retriever drives memory recall during
decode — the actual control flow used inside DeepSeek-V4-FlashMemory.
Inference flow
┌──────────┐ compress & store ┌────────────────────────────┐
│ PREFILL │ historical K/V │ CSA KV-cache (the memory) │
│ (dense │ ───────────────────► │ N compressed chunks, │
│ attn) │ │ each = [132] uint8 fp8-K │
└────┬─────┘ └──────────────┬─────────────┘
│ last hidden state │ scored every 64 steps
▼ │
┌──────────────────────── DECODE LOOP ──────────┼──────────────────────────┐
│ for each decode step t: │ │
│ hidden = toy_decoder.step(token, keep_mask) │ (sparse memory attn) │
│ │ │
│ every RETRIEVAL_INTERVAL (= 64) steps: ▼ │
│ scores[N] = retriever.ensemble(hidden, compressed_k, pos) │
│ keep_mask[N] = top-K (or sigmoid>thresh) of scores │
│ -> unselected chunks masked to -inf in next 64 steps │
└──────────────────────────────────────────────────────────────────────────┘
- Prefill (dense). Short prompt runs through dense memory attention. Its
last hidden state seeds the first retrieval cycle.
- Decode loop. Toy decoder produces a [B, 4096] hidden state each step.
- Retrieval cycle (every 64 steps). The real FlashMemoryRetriever scores
all N compressed-K chunks, ensembles per-layer scores, selects keep chunks.
- Sparse attention. Unselected chunks' attention logits are set to -inf.
What this simulates
- This toy does NOT perform real CPU↔GPU KV-cache transfer. The swap engine
is internal FlashMemory infrastructure and is not included.
- We simulate memory recall by masking attention logits to -inf. A masked
chunk contributes nothing to attention — the same effect as not loading its KV.
- The purpose is to make the decode-time control flow concrete.
What it is / is NOT
IS
IS NOT
Minimal torch-only illustration of memory recall
A runnable DeepSeek-V4
Uses the real retriever weights & scoring math
Production KV swap engine
Pedagogical: shows the control flow
Meaningful text generation
The production version depends on the internal sglang + DeepSeek-V4 CSA framework
(native FP8 indexer, real compressed KV-cache, attention-sink, threshold fallback,
per-request routing, actual KV swap) and cannot be released.
Downstream evaluation
FlashMemory DS-V4 beats or ties the full-attention baseline on reasoning-heavy
long-context tasks while keeping only ~10–15% of CSA KV cache on-device:
Task
Context
vs. Full-Attn
KV Saved
RULER (64k–512k)
64K–512K
−1 ~ +2 pp
~80–90%
LongMemEval-s
125K
±1 pp
~86%
LongMemEval-m
500K
±1 pp
~91%
LongBench V2
46K–493K
+1 ~ +2 pp
~73–90%
MRCR (needle)
274K
needs fallback
~86%
Precise needle-retrieval tasks (MRCR) require an additional threshold-fallback
in the serving layer — this is not part of the standalone release.
Files
File
Purpose
retriever.py
FlashMemoryRetriever model + RoPE/Hadamard + FP8 dequant
demo.py
Minimal demo with mock inputs
toy_flashmemory_inference.py
Toy sparse-decode loop
weights/flashmemory_ds_v4.safetensors
Trained weights (~510 MB, on Hugging Face)
requirements.txt
Dependencies
License
MIT
Citation
If you use FlashMemory in your research, please cite:
@article{wang2026flashmemory,
title = {FlashMemory-DeepSeek-V4: Lightning Index Ultra-Long Context via Lookahead Sparse Attention},
author = {Yan Wang and Qifan Zhang and Jiachen Yu and Tian Liang and Dongyang Ma and
Xiang Hu and Zibo Lin and Chunyang Li and Zhichao Wang and Jia Li and
Yujiu Yang and Haitao Mi and Dong Yu},
year = {2026},
journal = {arXiv preprint arXiv:2606.09079},
url = {https://huggingface.co/papers/2606.09079},
}関連記事
Cohere が開発者向けコード生成モデル「North Mini Code」を発表:30B パラメータの MoE アーキテクチャで 3B アクティブ
Cohere AI チームは、ソフトウェアエンジニア向けのオープンウェイトコード生成モデル「North Mini Code」を公開した。このモデルは総パラメータ数 30B の混合専門家(MoE)アーキテクチャを採用し、トークン処理時に 3B のパラメータのみが活性化するように設計されている。
DiffusionGemma:Google の高速テキスト生成モデルが再登場
Google は昨年実験的に公開した Gemini Diffusion モデルの研究を再開し、DiffusionGemma として再発表しました。このモデルは以前 1 秒間に 857 トークンの生成速度を記録しており、テキスト生成の高速化に寄与する技術です。
DiffusionGemma:テキスト生成が4倍高速化
Google DeepMind は、新しい手法「DiffusionGemma」を発表し、テキスト生成の速度を従来の4倍に向上させることに成功しました。
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み