RetNet入門
Spiral.AI Tech Blog は、Transformer の計算リソース問題を解決し、RNN の推論効率と Transformer の並列学習を両立する新アーキテクチャ「RetNet」の仕組みと実装詳細を解説している。
キーポイント
RetNet の核心:Retention 機構
RNN の回帰形式による省メモリ推論能力と、Transformer の行列計算による並列学習効率を両立する「Retention」機構を採用し、長文処理における計算資源の課題を克服している。
Chunk-wise 回帰表現による最適化
入力をチャンク単位に分解することで、長い系列データにおける学習時の計算速度とメモリ消費量を大幅に改善する手法が提案されている。
Multi-Scale Retention (MSR) の実装
Transformer の MHA に倣い複数のヘッドを保持しつつ、各ヘッドで異なる減衰率を設定するマルチスケール化とゲート機能により表現力を高めている。
推論効率とメモリ使用量の優位性
RetNet は回帰形式により推論コストを O(1) に抑えられ、系列長が長くても GPU メモリ消費量が一定に保たれるため、Transformer のような線形増加を防ぐ。
学習の並列化と高速スループット
従来の RNN と異なり学習時の並列化が可能であり、系列長が増大してもスループットがほとんど低下しない高性能を実現している。
非 Transformer アーキテクチャの動向
RetNet の他にも LEX Transformer、Linear Attention、RWKV、S4、Mamba などが提案されており、計算資源が限られた環境でも高性能な処理を目指す動きがある。
影響分析・編集コメントを表示
影響分析
本記事は、Transformer に代わるあるいは補完するアーキテクチャとしての RetNet の技術的優位性を明確に示しており、長文コンテキストやリアルタイム推論を必要とする次世代 AI システムの設計指針となる重要な知見を提供しています。特に計算効率と性能の両立という課題に対する具体的な解決策は、業界全体におけるモデル選定や開発戦略に影響を与える可能性があります。
編集コメント
Transformer の限界を打破する RetNet は、今後の大規模言語モデルの効率化において極めて重要な技術的転換点となるでしょう。本記事は複雑な数式解説を含みつつも、その実用価値を明確に伝えている良質な入門資料です。
architecture NLP transformer Generative AI retention techどうも、スプラで見かけた「ふくやまはさまる」のネーミングセンスに感服した@ksterxです。 Claude OpusさんやChatGPT大先生にこの意味をご説明願ったら、頓珍漢な答えが返ってきました。この辺が解せるようになったらAGIであるとかなんとか。
前置きはさておき、近年、自然言語処理(NLP)の分野は、ディープラーニングの進化とともに大きく変化しています。初期の深層学習アプローチではリカレントニューラルネットワーク(RNN)に焦点を当てていましたが、2017年の「Attention Is All You Need」という論文の登場により、それまでのアーキテクチャとは全く異なるTransformerアーキテクチャが覇権を握ることになります。しかし、Transformerアーキテクチャにも問題点があり、特に長い文書を処理する際に計算資源が膨大になるという課題がありました。この課題に対処するために、様々な非Transformer系のアーキテクチャが提案されています。
このブログでは、Transformerの問題点を克服しつつ、その強みを活かした比較的新しいアーキテクチャ「Retentive Network(RetNet)」について掘り下げていきます。
基本的に、断りが無ければ原論文から引用しています。
RetNetは、Transformerの強力な特性を保ちつつ、上述の問題点に対処する新しいアーキテクチャです。RNNの時系列データを高速に推論する能力と、Transformerの並列計算可能なアテンションによる全トークン間の関係性を捉える能力を組み合わせることで、効率的な学習と推論(+パフォーマンス)を可能にします。
この論文のコアとなる提案 Retention機構について見ていきましょう。入力Xから得られたcontent-awareなQ、Kと、関数v(n)を用いて
出力o_nが回帰形式で表されるものとします。ここでs_nを中間状態と呼ぶことにしましょう。
詳細は原論文に譲りますが、この中間状態を介した回帰形式から出発して、適当な仮定の下で式変形を行うと最終的に出力を行列形式で書き下すことができます。
これの何が嬉しいかと言うと、後に詳述しますが推論時には回帰的で省メモリな計算をしつつ、学習時には行列計算による高効率な並列計算が可能になるという、RNNとTransformerモデルのいいとこ取りができるということです。
並列表現について詳しく見てみましょう。
Retention
Attention (Source: All You Need Is Attentionより)
上がRetention機構で、下がAttention機構です。共役行列やSoftmaxなどの違いはあれど、アーキテクチャに関係なく、入力に対して動的に変化するQ、K、Vに関しては似通った形をしていることが見て取れます。
Source: Attention Is All You Need
Attention機構で提案されたような、クエリとキー間での演算をRetentionでも行っており、トークン間の関係性を捉えるような学習が可能であることを示唆していると考えられます。
次に、回帰表現について見てみましょう。
こちらは非常に単純で、キーバリュー演算と、前時刻の中間状態・減衰率の積の和が現時刻の中間状態になります。Retentionの計算結果はクエリとこの中間状態の積です。
Chunk-wise回帰表現というものも提案されていて、入力をB個のチャンクに分解することで、チャンク単位での並列表現と回帰表現を実現することで、特に長い系列での学習時の効率が、計算速度やメモリ消費量の観点で改善されています。
Multi-Scale Retention (MSR)
TransformerではMHA(Multi-Head Attention)のようにAttention機構を複数保持していましたが、RetNetでも同様にRetention機構を複数保持したアーキテクチャになっています。ただし、RetNetでは各headで用いられる減衰率が異なるように設定されます(マルチスケール化)。Gated Multi-Scale Retentionの出力としては、各headのretentionの正規化下出力と、ゲーティングした残差接続との間で要素積を取り、射影変換したものとなります。
l(エル)層目のデコーダーレイヤーは上のように書けます。MSRとFFNから成り、これを複数(例えば32層)積み重ねた構造をRetNetはしています。
次にRetNetが実際にどのような性能を持っているかを見てみましょう。
学習の並列化については、Transformer、RetNetともにできます。これは、先の並列表現のおかげですね。従来のRNNでは並列化ができないため学習効率が良くないです。
推論コストは、TransformerはAttention計算で最新のトークンのクエリと、過去と自身のキーに対する計算をする必要があるので、O(N)になります。対してRetNetやRNNは、回帰形式で計算することで、O(1)での推論が可能です。
キャッシュは、TransformerがKVキャッシュとしてO(N^2)、位置時刻前までのKVとしてO(N)を必要とします。実際実験でも、Transformerが系列長に対して線形にGPUメモリが増大するのに対し、RetNetでは系列長が増大してもGPUメモリは一定の値を取っていることが分かります。論文では、6.7Bモデル自体の重みが97%を占めており、キャッシュされる前時刻の状態等で必要になるGPUメモリは3%程度であると主張しています。
RetNetは系列長が増大しても、ほとんどスループットに変化はありません(そもそものスループットがすごいですね()
他の非Transformerアーキテクチャについて
Attention計算のデメリットを克服するためにRetNet以外にも様々な手法が提案されています。
ここでは、非Transformerアーキテクチャと便宜的に呼んではいますが、むしろTransformerの派生モデルというべきかもしれません。
RetNetと近しいモデルにLEX Transformerがあげられます。概要の方で、回帰表現と並列表現の同等性について触れましたが、その途中で出てくる表現が上の式です。これはxPosそのものであり、LEX TransformerではxPosを適用したQ、Kにattentionに適用しています。
他にも、Linear AttentionやRWKV、S4、Mambaと言ったものが非Transformerとしてあげられます。
その他にも、Linear Attention、RWKV、S4、Mambaなどのアーキテクチャが非Transformerアーキテクチャとして提案されています。これらのモデルは、学習効率や推論効率を向上させることを目指しており、計算資源が限られた環境でも高性能な処理が可能です。
現在、Transformer系のモデルが世界的に主流となっていますが、明示的に履歴を保持するAttentionは学習効率や推論効率が悪いことが課題となっており、計算資源が乏しい環境ではOpenAIやAnthropicのようなAPIを使わざるを得ない状況です。このような背景の中、RetNetのような新しいアーキテクチャが、NLPの分野に新たな可能性をもたらすかもしれません。
LLMエンジニアを募集しています!!
SpiralAIでは、生成AI×エンタメをテーマに様々なプロジェクトが立ち上がっています!最近では、マルチターン会話モデルを発表したりしています。 是非、ご興味があれば@ksterxや採用ページまでご連絡ください〜
GitHubで編集を提案 SpiralAIテックブログPublication実在する芸能人との会話ができる日本初のAIサービス「NaomiAI」やカスタムChatGPTを作れる「Spiralbot」を提供するSpiralAI株式会社のテックブログです。
原文を表示
architecture
NLP
transformer
Generative AI
retention techどうも、スプラで見かけた「ふくやまはさまる」のネーミングセンスに感服した@ksterxです。 Claude OpusさんやChatGPT大先生にこの意味をご説明願ったら、頓珍漢な答えが返ってきました。この辺が解せるようになったらAGIであるとかなんとか。
前置きはさておき、近年、自然言語処理(NLP)の分野は、ディープラーニングの進化とともに大きく変化しています。初期の深層学習アプローチではリカレントニューラルネットワーク(RNN)に焦点を当てていましたが、2017年の「Attention Is All You Need」という論文の登場により、それまでのアーキテクチャとは全く異なるTransformerアーキテクチャが覇権を握ることになります。しかし、Transformerアーキテクチャにも問題点があり、特に長い文書を処理する際に計算資源が膨大になるという課題がありました。この課題に対処するために、様々な非Transformer系のアーキテクチャが提案されています。
このブログでは、Transformerの問題点を克服しつつ、その強みを活かした比較的新しいアーキテクチャ「Retentive Network(RetNet)」について掘り下げていきます。
基本的に、断りが無ければ原論文から引用しています。
RetNetは、Transformerの強力な特性を保ちつつ、上述の問題点に対処する新しいアーキテクチャです。RNNの時系列データを高速に推論する能力と、Transformerの並列計算可能なアテンションによる全トークン間の関係性を捉える能力を組み合わせることで、効率的な学習と推論(+パフォーマンス)を可能にします。 
この論文のコアとなる提案 Retention機構について見ていきましょう。入力Xから得られたcontent-awareなQ、Kと、関数v(n)を用いて 
出力o_nが回帰形式で表されるものとします。ここでs_nを中間状態と呼ぶことにしましょう。
詳細は原論文に譲りますが、この中間状態を介した回帰形式から出発して、適当な仮定の下で式変形を行うと最終的に出力を行列形式で書き下すことができます。 
これの何が嬉しいかと言うと、後に詳述しますが推論時には回帰的で省メモリな計算をしつつ、学習時には行列計算による高効率な並列計算が可能になるという、RNNとTransformerモデルのいいとこ取りができるということです。
並列表現について詳しく見てみましょう。
Retention
Attention (Source: All You Need Is Attentionより)
上がRetention機構で、下がAttention機構です。共役行列やSoftmaxなどの違いはあれど、アーキテクチャに関係なく、入力に対して動的に変化するQ、K、Vに関しては似通った形をしていることが見て取れます。

Source: Attention Is All You Need
Attention機構で提案されたような、クエリとキー間での演算をRetentionでも行っており、トークン間の関係性を捉えるような学習が可能であることを示唆していると考えられます。
次に、回帰表現について見てみましょう。

こちらは非常に単純で、キーバリュー演算と、前時刻の中間状態・減衰率の積の和が現時刻の中間状態になります。Retentionの計算結果はクエリとこの中間状態の積です。

Chunk-wise回帰表現というものも提案されていて、入力をB個のチャンクに分解することで、チャンク単位での並列表現と回帰表現を実現することで、特に長い系列での学習時の効率が、計算速度やメモリ消費量の観点で改善されています。
Multi-Scale Retention (MSR)

TransformerではMHA(Multi-Head Attention)のようにAttention機構を複数保持していましたが、RetNetでも同様にRetention機構を複数保持したアーキテクチャになっています。ただし、RetNetでは各headで用いられる減衰率が異なるように設定されます(マルチスケール化)。Gated Multi-Scale Retentionの出力としては、各headのretentionの正規化下出力と、ゲーティングした残差接続との間で要素積を取り、射影変換したものとなります。

l(エル)層目のデコーダーレイヤーは上のように書けます。MSRとFFNから成り、これを複数(例えば32層)積み重ねた構造をRetNetはしています。
次にRetNetが実際にどのような性能を持っているかを見てみましょう。

学習の並列化については、Transformer、RetNetともにできます。これは、先の並列表現のおかげですね。従来のRNNでは並列化ができないため学習効率が良くないです。
推論コストは、TransformerはAttention計算で最新のトークンのクエリと、過去と自身のキーに対する計算をする必要があるので、O(N)になります。対してRetNetやRNNは、回帰形式で計算することで、O(1)での推論が可能です。
キャッシュは、TransformerがKVキャッシュとしてO(N^2)、位置時刻前までのKVとしてO(N)を必要とします。実際実験でも、Transformerが系列長に対して線形にGPUメモリが増大するのに対し、RetNetでは系列長が増大してもGPUメモリは一定の値を取っていることが分かります。論文では、6.7Bモデル自体の重みが97%を占めており、キャッシュされる前時刻の状態等で必要になるGPUメモリは3%程度であると主張しています。

RetNetは系列長が増大しても、ほとんどスループットに変化はありません(そもそものスループットがすごいですね()

他の非Transformerアーキテクチャについて
Attention計算のデメリットを克服するためにRetNet以外にも様々な手法が提案されています。
ここでは、非Transformerアーキテクチャと便宜的に呼んではいますが、むしろTransformerの派生モデルというべきかもしれません。

RetNetと近しいモデルにLEX Transformerがあげられます。概要の方で、回帰表現と並列表現の同等性について触れましたが、その途中で出てくる表現が上の式です。これはxPosそのものであり、LEX TransformerではxPosを適用したQ、Kにattentionに適用しています。
他にも、Linear AttentionやRWKV、S4、Mambaと言ったものが非Transformerとしてあげられます。
その他にも、Linear Attention、RWKV、S4、Mambaなどのアーキテクチャが非Transformerアーキテクチャとして提案されています。これらのモデルは、学習効率や推論効率を向上させることを目指しており、計算資源が限られた環境でも高性能な処理が可能です。
現在、Transformer系のモデルが世界的に主流となっていますが、明示的に履歴を保持するAttentionは学習効率や推論効率が悪いことが課題となっており、計算資源が乏しい環境ではOpenAIやAnthropicのようなAPIを使わざるを得ない状況です。このような背景の中、RetNetのような新しいアーキテクチャが、NLPの分野に新たな可能性をもたらすかもしれません。
LLMエンジニアを募集しています!!
SpiralAIでは、生成AI×エンタメをテーマに様々なプロジェクトが立ち上がっています!最近では、マルチターン会話モデルを発表したりしています。 是非、ご興味があれば@ksterxや採用ページまでご連絡ください〜
GitHubで編集を提案
SpiralAIテックブログPublication実在する芸能人との会話ができる日本初のAIサービス「NaomiAI」やカスタムChatGPTを作れる「Spiralbot」を提供するSpiralAI株式会社のテックブログです。

関連記事
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み