1ペタバイトのデータセットで機械学習する / WebDataset入門
TURING Tech Blog は、ペタバイト規模の自動運転データ処理における I/O ボトルネックを解決する WebDataset の実装と、大規模分散学習への適用事例を紹介している。
キーポイント
大規模データセットの I/O ボトルネック
GPU 性能だけでなく、ネットワーク帯域やストレージのディスク I/O が大規模機械学習のボトルネックとなりやすく、従来のファイルシステムでは均一なランダムアクセスやシャッフル処理が非効率である。
WebDataset の統一アーキテクチャ
オンメモリからクラウドストレージ(S3)、高速分散ファイルシステム(Lustre/GPFS)まで、異なるストレージ環境に対して統一的かつ効率的なデータローダーを提供するライブラリ。
自動運転分野での実証規模
TURING は 50,000 時間の走行データ(約 1 ペタバイト)を収集し、これを PyTorch と WebDataset を組み合わせて処理する技術検証を進めている。
開発から本番へのスケーラビリティ
小規模なテスト環境からペタスケールの学習環境へ移行する際、データローダーの実装やチューニングを最小限に抑え、コードの再利用性を高めるアプローチ。
POSIX tar形式の活用と拡張性
WebDataset は特殊な変換を不要とし、標準的な POSIX tar ファイルをシャーディングしてストリーミング読み込み可能にするため、小規模データでの開発から大規模データへのスケーリングが容易です。
PyTorch 連携とデータパイプライン
PyTorch の IterableDataset を継承し、decode や map、compose を用いて画像のデコード、ラベル抽出、ランダム切り出し(データ拡張)などの前処理を柔軟に定義できます。
外部依存なしの軽量設計
Python で書かれた独立したライブラリであり、シリアライズ形式の変換が不要なため、既存の tar ファイルをそのまま利用して高速な I/O を実現します。
影響分析・編集コメントを表示
影響分析
この記事は、大規模深層学習の実践において、ハードウェア性能だけでなくデータパイプラインの最適化がいかに決定的かを示す重要な事例です。特に自動運転分野のようなペタバイト級データを扱う業界において、WebDataset のような専用ライブラリを採用することで、インフラ変更コストを下げつつスケーラビリティを確保する標準的なアプローチを提示しています。
編集コメント
ハードウェアの性能向上だけでは解決できないデータ I/O の課題に対し、ソフトウェア層での工夫(WebDataset)で対応する実務的な視点が光る記事です。大規模モデル開発におけるインフラ設計の参考になります。
1ペタバイトのデータセットで機械学習する / WebDataset入門
深層学習をする上で、最も大切なマシンスペックを聞かれたら何と答えますか? GPUのTensor性能、VRAM、GPUの数、CPU性能、メモリ、… 問題によって正解は異なりますね。
しかし、特に大規模なデータセットで機械学習する場合では、しばしばネットワーク帯域とストレージシステムのディスクI/Oによって制限されます。この記事ではそのような課題に対して、学習側でどのようにデータを扱うかを見ていきたいと思います。
こんにちは、TURING MLチームです。TURINGはEnd-to-Endな深層学習モデルでLv5完全自動運転車の開発を目指す会社です。
私たちは自動運転モデルを動かすため、可視域のカメラセンサによる画像で学習し、カメラ映像のみから車体の操作や経路選択、安全性の判断を行わせています。(実際の車を動かす事例はこちらの記事をご覧ください。)
そのため、機械学習のために大量の画像データが必要になってきます。TURINGでは2022年に500時間、2023年に50,000時間の公道上の走行データをカメラによって収集する計画を立てています。50,000時間、というとピンと来ないかと思いますが、仮に平均時速25kmだとすると、合計で125万kmになります。これは日本の道路総延長(=128万km) に匹敵する距離です。(もちろん、日本の全ての道路を走行するわけではなく、都市部を中心に同じ道をさまざまな角度・条件で撮影していくことになります。)
重要なことは、データが大量の動画という形で取得されるという点です。動画はエンコードされて圧縮されている状態でも~1GB/時間、複数カメラで機械学習用のテンソルに成形すると数十GB/時間程度にもなります。そのため、50,000時間の走行でおよそ1ペタバイト程度のデータとなります。
TURINGはこのような大規模な機械学習に向け、ストレージシステムの検討やI/Oパフォーマンスの測定、そして100基単位でのGPUによる並列分散学習のための技術検証を進めています。この記事では、開発段階の小規模なデータセットからペタバイトスケールの機械学習まで対応できるデータローダーの仕組みの一例として、PyTorchで学習するケースを紹介していきたいと思います。
- 学習データをロードする5つのシナリオ
機械学習でデータセットを計算サーバに転送するには5つの方法があります。
(1) オンメモリ (2) ローカルディスク (3) Web/ファイルサーバー (4) クラウドストレージ (5) 高速分散ストレージシステム
全てのデータがメモリ上に収まるサイズのデータセットであれば、多くのケースで特別なデータハンドリングは必要ありません。一度データを読み出せば高速で処理が可能です。一方、画像分類等の大きな訓練用データセットは、しばしばメモリサイズを超過します。そのようなケースでは、ディスク上に配置されたファイルを逐次読み込む必要があります。
インターネット上に公開されているデータセットは、前もって全てダウンロードすればローカルディスクと同様に扱えますが、ネットワーク帯域によっては多くの時間が要求されます。
また、Amazon S3など、クラウドストレージ上に保存されているデータを読み出す場合、ファイルシステムとしてマウントしたり、公式のS3 IO datapipesを使ってデータパイプラインを作成します。さらに大規模な環境では、LustreやGPFSなど高速な並列分散ファイルシステムが採用されることがありますが、このような場合でもやはりネットワーク帯域やファイルシステムのパフォーマンスに影響を受けます。
一方、どのようなストレージシステムを利用するにせよ、機械学習のデータセットへのアクセスは共通して下記のような特徴を持ちます。このようなデータアクセスは機械学習特有なもので、既存のファイルシステムとうまく適合しない可能性があります。
均一にランダムなアクセスパターンを持つ
多くの(ときには数十億の)ファイルで構成されている
学習前にデータをシャッフルや前処理、データ拡張する必要がある
小さなデータセットで開発/テストをしてから、大きなデータセットにスケールアップさせていく場合、ストレージシステムを変更する状況がしばしば生じます。学習データのスケールにあわせ、システムごとにデータローダーの実装やパフォーマンスチューニングをするのは大変にコストがかかります。そのため、ファイル形式やロードのためのコードをできるだけ変更せずに対応したいわけですが、どうしたらよいでしょうか?
- WebDatasetとは
PyTorchに対するWebDatasetライブラリは、このようなデータ読み込みの問題を解決し、(1)~(5)までの全てのシナリオに対してペタスケールまでの統一的で効率のよいアクセスを提供してくれます。
WebDataset is an ideal solution for training on petascale datasets kept on high performance distributed data stores like AIStore, AWS/S3, and Google Cloud.
WebDataset also is very useful for such smaller datasets, and it can easily be used for developing and testing on small datasets and then scaling up to large datasets by simply using more shards.
WebDatasetはBigDat2019で発表されたHigh Performance I/O For Large Scale Deep Learningで示された大規模な深層学習のためのデータセット機構とその実装です。WebDatasetは任意のストレージシステムにデータを数十~数百MBごとにシャーディング(分割)して配置し、シーケンシャルに読み込むことででストリーミングによるアクセスを可能にしています。
WebDatasetはPythonで書かれた外部依存性のない独立したライブラリとして開発されており、将来的にPyTorchのサブパッケージとして取り込まれるための提案がなされています(RCF 38419)。
同様のライブラリとしてTensorFlowのTFRecordがありますが、WebDatasetではPOSIX tarによるファイルベースを採用しており、シリアライズされた形式に変換する必要がないという特徴があります。
$ pip install webdataset
$ pip install git+https://github.com/tmbdev/webdataset.git
ここでは文書画像のデータセットであるPubLayNetを用いて説明してきたと思います。PubLayNetは、ドキュメント画像の大規模なデータセットで、そのレイアウトには、境界ボックスと多角形のセグメンテーションの両方で注釈が付けられています。
まずデータセットの一部として、290MB程度のシャードファイルを適当な場所(ここでは/tmp
$ curl -L "http://storage.googleapis.com/nvdata-publaynet/publaynet-train-000000.tar" -o "/tmp/publaynet_000000.tar"
実態は普通のtarファイルで、中身は画像ファイル(png) とメタデータ(json) のセットが985組アーカイブされています。
$ tar -tf /tmp/publaynet_000000.tar | head PMC4991227_00003.json PMC4991227_00003.png PMC4537884_00002.json PMC4537884_00002.png PMC4323233_00003.json PMC4323233_00003.png PMC5429906_00004.json PMC5429906_00004.png PMC5592712_00002.json PMC5592712_00002.png $ tar -tf /tmp/publaynet_000000.tar | wc -l 1970
このように、WebDatasetでは特殊なデータ形式に変更することなくPOSIX tar形式でアーカイブされたファイルを読み出すことができます。データセットは標準のtarコマンドでも容易に作成することができます。
4-3. ローカルディスクからの読み出し
import torch import webdataset as wds url = "/tmp/publaynet_000000.tar" # ローカルファイルパスをセットします. dataset = wds.WebDataset(url) # データパイプラインの定義(デコード、タプル化). dataset = dataset.decode("rgb").to_tuple("png", "json")
webdataset.WebDataset
IterableDataset
print(isinstance(dataset, torch.utils.data.IterableDataset)) # True # データを取得します. image, json = next(iter(dataset)) print(image.shape, image.dtype, type(json)) # (794, 610, 3) float32 <class 'dict'>
データセットの前処理として、サンプリングされたデータに対し、任意の関数をmap()
def preprocess(sample): image, json = sample try: label = json["annotations"][0]["category_id"] except Exception: label = 0 return image, label dataset = dataset.map(preprocess)
さらにデータ拡張として、画像データから、さらにランダムに256x256のイメージサイズに切り出す処理を入れてみます。compose()
from random import randrange def get_patches(source): for sample in source: image, label = sample # サンプリングされた画像のheight/widthを取得します. h, w = image.shape[:2] for _ in range(16): y, x = randrange(h - 256), randrange(w - 256) patch = image[y : y + 256, x : x + 256] yield (patch, label) dataset = dataset.compose(get_patches) dataset = dataset.shuffle(10000) # バッファーサイズ=10000でシャッフルします.
データセットの準備としてはこれで終わりです。 最後にPyTorch標準のDataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4) images, labels = next(iter(dataloader)) print(images.shape, labels.shape) # torch.Size([64, 256, 256, 3]) torch.Size([64])
また、複数のファイルを読み出すには最初のURLをpublaynet_{000000...000009}.tar
import torch import webdataset as wds url = "/tmp/publaynet_{000000..000009}.tar" dataset = wds.WebDataset(url) # 以下同じ
先ほどはtarファイルをローカルにダウンロードしてデータセットとしました。WebDatasetではWebサーバのURLを直接設定することが可能です。url
http://storage.googleapis.com/nvdata-publaynet/publaynet-train-{000000..000009}.tar
import torch import webdataset as wds url = "http://storage.googleapis.com/nvdata-publaynet/publaynet-train-{000000..000009}.tar" dataset = wds.WebDataset(url) dataset = dataset.decode("rgb").to_tuple("png", "json") dataset = dataset.map(preprocess) dataset = dataset.compose(get_patches) dataset = dataset.shuffle(10000) dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4) images, labels = next(iter(dataloader)) print(images.shape, labels.shape) # torch.Size([64, 256, 256, 3]) torch.Size([64])
ローカルストレージの場合とURLを変えるだけで、他のコードを変更することなくWebサーバから連続的に訓練用のバッチを取得することができます。WebDatasetでは任意の前処理やデータ拡張を定義してデータパイプラインとしてストリーミングできるため、学習前に個別にファイルをダウンロードする必要はありません。
最後にAmazon S3に置かれたファイルを取得するケースをみていきます。S3から学習データセットを利用する場合、3つの選択肢があります。
オブジェクトをバイトストリームとして直接ロードする
S3バケットをファイルシステムとしてマウントする
SageMakerパイプモードを(解析して)利用する
SageMakerパイプモードはAmazon SageMakerが提供するS3に保存されているデータをやり取りするための専用APIです。パイプモードを使用すると、データは専用のLinux FIFOパイプを介して高速にストリーミングされます。データバイナリを解析する必要があるため、ここでは前の二つの方法を見ていきたいと思います。
オブジェクトのストリーム[1]
Amazon S3の場合、ローカルストレージ/Webサーバと異なり、直接URLを指定することができません。そこでPythonのBoto3ライブラリを用い、S3のオブジェクトを直接(ストレージを介さずに)メモリのバイトストリームとして取り込みます。
import io import re import boto3 client = boto3.client( "s3", aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"] ) def get_stream(path): stream = io.BytesIO() _, bucket, key, _ = re.split("s3://(.*?)/(.*)$", path) client.download_fileobj(bucket, key, byte_io) stream.seek(0) return stream
さらにこれをWebDataset.tariterators.url_opener
webdataset.WebDataset
def url_opener(data, handler=reraise_exeption, **kwd): for sample in data: url = sample["url"] try: stream = get_stream(url) sample.update(stream=stream) yield sample except Excep
原文を表示
機械学習
PyTorch
自動運転
WebDataset tech

深層学習をする上で、最も大切なマシンスペックを聞かれたら何と答えますか? GPUのTensor性能、VRAM、GPUの数、CPU性能、メモリ、… 問題によって正解は異なりますね。
しかし、特に大規模なデータセットで機械学習する場合では、しばしばネットワーク帯域とストレージシステムのディスクI/Oによって制限されます。この記事ではそのような課題に対して、学習側でどのようにデータを扱うかを見ていきたいと思います。
こんにちは、TURING MLチームです。TURINGはEnd-to-Endな深層学習モデルでLv5完全自動運転車の開発を目指す会社です。
私たちは自動運転モデルを動かすため、可視域のカメラセンサによる画像で学習し、カメラ映像のみから車体の操作や経路選択、安全性の判断を行わせています。(実際の車を動かす事例はこちらの記事をご覧ください。)
そのため、機械学習のために大量の画像データが必要になってきます。TURINGでは2022年に500時間、2023年に50,000時間の公道上の走行データをカメラによって収集する計画を立てています。50,000時間、というとピンと来ないかと思いますが、仮に平均時速25kmだとすると、合計で125万kmになります。これは日本の道路総延長(=128万km) に匹敵する距離です。(もちろん、日本の全ての道路を走行するわけではなく、都市部を中心に同じ道をさまざまな角度・条件で撮影していくことになります。)

重要なことは、データが大量の動画という形で取得されるという点です。動画はエンコードされて圧縮されている状態でも~1GB/時間、複数カメラで機械学習用のテンソルに成形すると数十GB/時間程度にもなります。そのため、50,000時間の走行でおよそ1ペタバイト程度のデータとなります。
TURINGはこのような大規模な機械学習に向け、ストレージシステムの検討やI/Oパフォーマンスの測定、そして100基単位でのGPUによる並列分散学習のための技術検証を進めています。この記事では、開発段階の小規模なデータセットからペタバイトスケールの機械学習まで対応できるデータローダーの仕組みの一例として、PyTorchで学習するケースを紹介していきたいと思います。
- 学習データをロードする5つのシナリオ
機械学習でデータセットを計算サーバに転送するには5つの方法があります。
(1) オンメモリ (2) ローカルディスク (3) Web/ファイルサーバー (4) クラウドストレージ (5) 高速分散ストレージシステム
全てのデータがメモリ上に収まるサイズのデータセットであれば、多くのケースで特別なデータハンドリングは必要ありません。一度データを読み出せば高速で処理が可能です。一方、画像分類等の大きな訓練用データセットは、しばしばメモリサイズを超過します。そのようなケースでは、ディスク上に配置されたファイルを逐次読み込む必要があります。
インターネット上に公開されているデータセットは、前もって全てダウンロードすればローカルディスクと同様に扱えますが、ネットワーク帯域によっては多くの時間が要求されます。
また、Amazon S3など、クラウドストレージ上に保存されているデータを読み出す場合、ファイルシステムとしてマウントしたり、公式のS3 IO datapipesを使ってデータパイプラインを作成します。さらに大規模な環境では、LustreやGPFSなど高速な並列分散ファイルシステムが採用されることがありますが、このような場合でもやはりネットワーク帯域やファイルシステムのパフォーマンスに影響を受けます。
一方、どのようなストレージシステムを利用するにせよ、機械学習のデータセットへのアクセスは共通して下記のような特徴を持ちます。このようなデータアクセスは機械学習特有なもので、既存のファイルシステムとうまく適合しない可能性があります。
均一にランダムなアクセスパターンを持つ
多くの(ときには数十億の)ファイルで構成されている
学習前にデータをシャッフルや前処理、データ拡張する必要がある
小さなデータセットで開発/テストをしてから、大きなデータセットにスケールアップさせていく場合、ストレージシステムを変更する状況がしばしば生じます。学習データのスケールにあわせ、システムごとにデータローダーの実装やパフォーマンスチューニングをするのは大変にコストがかかります。そのため、ファイル形式やロードのためのコードをできるだけ変更せずに対応したいわけですが、どうしたらよいでしょうか?
- WebDatasetとは
PyTorchに対するWebDatasetライブラリは、このようなデータ読み込みの問題を解決し、(1)~(5)までの全てのシナリオに対してペタスケールまでの統一的で効率のよいアクセスを提供してくれます。
WebDataset is an ideal solution for training on petascale datasets kept on high performance distributed data stores like AIStore, AWS/S3, and Google Cloud.
WebDataset also is very useful for such smaller datasets, and it can easily be used for developing and testing on small datasets and then scaling up to large datasets by simply using more shards.
WebDatasetはBigDat2019で発表されたHigh Performance I/O For Large Scale Deep Learningで示された大規模な深層学習のためのデータセット機構とその実装です。WebDatasetは任意のストレージシステムにデータを数十~数百MBごとにシャーディング(分割)して配置し、シーケンシャルに読み込むことででストリーミングによるアクセスを可能にしています。
WebDatasetはPythonで書かれた外部依存性のない独立したライブラリとして開発されており、将来的にPyTorchのサブパッケージとして取り込まれるための提案がなされています(RCF 38419)。
同様のライブラリとしてTensorFlowのTFRecordがありますが、WebDatasetではPOSIX tarによるファイルベースを採用しており、シリアライズされた形式に変換する必要がないという特徴があります。
$ pip install webdataset
$ pip install git+https://github.com/tmbdev/webdataset.git
ここでは文書画像のデータセットであるPubLayNetを用いて説明してきたと思います。PubLayNetは、ドキュメント画像の大規模なデータセットで、そのレイアウトには、境界ボックスと多角形のセグメンテーションの両方で注釈が付けられています。
Publaynetの画像データ
まずデータセットの一部として、290MB程度のシャードファイルを適当な場所(ここでは/tmp
$ curl -L "http://storage.googleapis.com/nvdata-publaynet/publaynet-train-000000.tar" -o "/tmp/publaynet_000000.tar"
実態は普通のtarファイルで、中身は画像ファイル(png) とメタデータ(json) のセットが985組アーカイブされています。
$ tar -tf /tmp/publaynet_000000.tar | head PMC4991227_00003.json PMC4991227_00003.png PMC4537884_00002.json PMC4537884_00002.png PMC4323233_00003.json PMC4323233_00003.png PMC5429906_00004.json PMC5429906_00004.png PMC5592712_00002.json PMC5592712_00002.png $ tar -tf /tmp/publaynet_000000.tar | wc -l 1970
このように、WebDatasetでは特殊なデータ形式に変更することなくPOSIX tar形式でアーカイブされたファイルを読み出すことができます。データセットは標準のtarコマンドでも容易に作成することができます。
4-3. ローカルディスクからの読み出し
import torch import webdataset as wds url = "/tmp/publaynet_000000.tar" # ローカルファイルパスをセットします. dataset = wds.WebDataset(url) # データパイプラインの定義(デコード、タプル化). dataset = dataset.decode("rgb").to_tuple("png", "json")
webdataset.WebDataset
IterableDataset
print(isinstance(dataset, torch.utils.data.IterableDataset)) # True # データを取得します. image, json = next(iter(dataset)) print(image.shape, image.dtype, type(json)) # (794, 610, 3) float32 <class 'dict'>
データセットの前処理として、サンプリングされたデータに対し、任意の関数をmap()
def preprocess(sample): image, json = sample try: label = json["annotations"][0]["category_id"] except Exception: label = 0 return image, label dataset = dataset.map(preprocess)
さらにデータ拡張として、画像データから、さらにランダムに256x256のイメージサイズに切り出す処理を入れてみます。compose()
from random import randrange def get_patches(source): for sample in source: image, label = sample # サンプリングされた画像のheight/widthを取得します. h, w = image.shape[:2] for _ in range(16): y, x = randrange(h - 256), randrange(w - 256) patch = image[y : y + 256, x : x + 256] yield (patch, label) dataset = dataset.compose(get_patches) dataset = dataset.shuffle(10000) # バッファーサイズ=10000でシャッフルします.
データセットの準備としてはこれで終わりです。 最後にPyTorch標準のDataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4) images, labels = next(iter(dataloader)) print(images.shape, labels.shape) # torch.Size([64, 256, 256, 3]) torch.Size([64])
また、複数のファイルを読み出すには最初のURLをpublaynet_{000000...000009}.tar
import torch import webdataset as wds url = "/tmp/publaynet_{000000..000009}.tar" dataset = wds.WebDataset(url) # 以下同じ
先ほどはtarファイルをローカルにダウンロードしてデータセットとしました。WebDatasetではWebサーバのURLを直接設定することが可能です。url
http://storage.googleapis.com/nvdata-publaynet/publaynet-train-{000000..000009}.tar
import torch import webdataset as wds url = "http://storage.googleapis.com/nvdata-publaynet/publaynet-train-{000000..000009}.tar" dataset = wds.WebDataset(url) dataset = dataset.decode("rgb").to_tuple("png", "json") dataset = dataset.map(preprocess) dataset = dataset.compose(get_patches) dataset = dataset.shuffle(10000) dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4) images, labels = next(iter(dataloader)) print(images.shape, labels.shape) # torch.Size([64, 256, 256, 3]) torch.Size([64])
ローカルストレージの場合とURLを変えるだけで、他のコードを変更することなくWebサーバから連続的に訓練用のバッチを取得することができます。WebDatasetでは任意の前処理やデータ拡張を定義してデータパイプラインとしてストリーミングできるため、学習前に個別にファイルをダウンロードする必要はありません。
最後にAmazon S3に置かれたファイルを取得するケースをみていきます。S3から学習データセットを利用する場合、3つの選択肢があります。
オブジェクトをバイトストリームとして直接ロードする
S3バケットをファイルシステムとしてマウントする
SageMakerパイプモードを(解析して)利用する
SageMakerパイプモードはAmazon SageMakerが提供するS3に保存されているデータをやり取りするための専用APIです。パイプモードを使用すると、データは専用のLinux FIFOパイプを介して高速にストリーミングされます。データバイナリを解析する必要があるため、ここでは前の二つの方法を見ていきたいと思います。
オブジェクトのストリーム[1]
Amazon S3の場合、ローカルストレージ/Webサーバと異なり、直接URLを指定することができません。そこでPythonのBoto3ライブラリを用い、S3のオブジェクトを直接(ストレージを介さずに)メモリのバイトストリームとして取り込みます。
import io import re import boto3 client = boto3.client( "s3", aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"] ) def get_stream(path): stream = io.BytesIO() _, bucket, key, _ = re.split("s3://(.*?)/(.*)$", path) client.download_fileobj(bucket, key, byte_io) stream.seek(0) return stream
さらにこれをWebDataset.tariterators.url_opener
webdataset.WebDataset
def url_opener(data, handler=reraise_exeption, **kwd): for sample in data: url = sample["url"] try: stream = get_stream(url) sample.update(stream=stream) yield sample except Exception as e: e.args = e.args + (url,) if handler(e): continue else: break # url_openerをオーバーライドします. wds.tariterators.url_opener = url_opener urls = [f"s3://<path of dataset>/publaynet_{i:06d}.tar" for i in range(10)] dataset = wds.WebDataset(urls)
Amazon S3はS3Fs、goofysなどFUSEベースのライブラリでマウントすることができます。またここでは扱いませんが、WebDataset以外でも、公式のS3 Pytorchプラグインをを使用することで、IterableDataset
注: はじめにstart methodをspawnにする必要があります. torch.multiprocessing.set_start_method("spawn") import s3fs fs = s3fs.S3FileSystem( key=os.environ["AWS_ACCESS_KEY_ID"], secret=os.environ["AWS_SECRET_ACCESS_KEY"] ) def url_opener(data, handler=reraise_exeption, **kwd): for sample in data: url = sample["url"] try: stream = fs.open(url.replace("s3://", ""), mode="rb") sample.update(stream=stream) except Exception as e: e.args = e.args + (url,) if handler(e): continue else: break # url_openerをオーバーライドします. wds.tarietators.url_opener = url_opener urls = [f"s3://<path of dataset>/publaynet_{i:06d}.tar" for i in range(10)] dataset = wds.WebDataset(urls)
WebDatasetで扱うのは単なるtarファイルなので、通常はtarコマンドを使用するだけで作成可能です。
$ tar --sort=name -cf dataset.tar dataset/
またはGoで実装されたtarpをインストールし、tarp create
tarpのインストール $ go get -v github.com/tmbdev/tarp/tarp
また、既存のデータセットに対して、webdataset.TarWriter
sink = wds.TarWriter("dest.tar") for index, (input, output) in dataset: sink.write({ "__key__": "sample%06d" % index, "input.png": input, "output.cls": output, }) sink.close()
最後に、WebDatasetでデータを取得する場合のパフォーマンスを計測してみたいと思います。データセットはPubLaynetを用いて、1エポックあたりの平均読み込み速度を計測します。対象としたのは10GB相当のシャードファイルで約34500個分のイメージデータセットです。
ローカルディスクにpng/jsonとして展開したものをPyTorch標準のDatasetで読み込み
ローカルディスクのtarファイルをWebDatasetで読み込み
Webサーバーからダウンロードし、PyTorch標準のDatasetで読み込み
WebサーバーからWebDatasetでストリーミングして読み込み
Amazon S3からWebDatasetでバイトストリームで読み込み
Amazon S3からWebDatasetでFUSEマウントで読み込み
1epochあたりのロード時間
ローカルディスク + Dataset
ローカルディスク + WebDataset
(Webサーバーからの直接ダウンロード時間)
Webサーバーダウンロード + Dataset
Webサーバー + WebDataset
(Amazon S3からの直接ダウンロード時間)
Amazon S3 + WebDataset (バイトストリーム)
Amazon S3 + WebDataset (FUSEマウント)
ローカルディスクで展開済のデータではPyTorch標準のDatasetが上回っていますが、Webサーバー・Aamzon S3からストリーミングする場合ではファイルを直接ダウンロードするのと同等の時間で(つまり学習中にネットワーク帯域をフルに使って)データセットをロードすることができるという結果になりました。シンプルな実装でさまざまなデータローディングに対応できる、スケールアウトも容易な点は大きなメリットかと思います。
- おわりに、そして超大規模学習に向けて
今回はペタバイトスケールの機械学習にも対応可能なデータローダーの検証としてWebDatasetの使い方やパフォーマンスを紹介しました。あくまで機械学習の実装視点からのものですので、実際の大規模学習ではストレージシステムやネットワーク、並列分散で計算させるGPUクラスタなど、さまざまな要素が必要になってきます。
そして、TURINGではこのような大規模な機械学習モデルをつくるエンジニアを募集しています。
- 真に大規模な深層学習モデルの設計・実験・評価 - ペタバイトスケールのデータストレージシステム/クラウドインフラ構築・管理 - 走行データ収集のためのiOS/Androidアプリ開発・運用
詳しくはTURINGの採用ページをご覧ください。一緒に完全な車をつくりませんか? ご不明な点があればTURING MLチーム 担当者 (Yu Yamaguchi)までお気軽にDMしてください。
https://www.wantedly.com/projects/1024347 https://www.wantedly.com/companies/turing-motors/projects
Training in PyTorch from Amazon S3 ↩︎
Tech Blog - TuringPublication人類未到の完全自動運転を目指すスタートアップ・チューリング株式会社の公式テックブログです。

今日のまとめ
AI日報で今日の重要ニュースをまとめ読み