Path: blob/master/site/ja/datasets/tfless_tfds.ipynb
25115 views
Copyright 2023 The TensorFlow Datasets Authors.
Jax と PyTorch 用の TFDS
TFDS は常に フレームワーク非依存型でした。たとえば、NumPy 形式のデータセットを簡単に読み込んで、Jax と PyTorch で使用することができます。
TensorFlow とそのデータ読み込みソリューション(tf.data
)は、設計上、API の第一級市民です。
TensorFlow を使用せずに NumPy のみでデータを読み込めるように、TFDS を拡張しました。これは、Jax や PyTorch などの ML での使用に便利であり、実際に PyTorch ユーザーの場合、TensorFlow では以下のことが発生する可能性があります。
GPU/TPU メモリの予約
CI/CD でのビルド時間の長期化
ランタイム時のインポートの長期化
TensorFlow は、データセットを読み取る際の依存関係ではなくなりました。
ML パイプラインがサンプルを読み込んで解読し、モデルに提供するには、データローダーが必要です。データローダーは、「ソース/サンプラー/ローダー」パラダイムを使用します。
データソースは、TFDS データセットからオンザフライ方式でサンプルにアクセスして解読します。
インデックスサンプラーは、レコードが処理される順序を決定します。これは、レコードを読み取る前にグローバル変換(グローバルシャッフル、シャーディング、複数のエポックの反復など)を実装するのに重要です。
データローダーは、データソースとインデックスサンプラーを利用して、読み込みをオーケストレーションします。パフォーマンスの最適化が可能です(プリフェッチ、マルチプロセッシング、またはマルチスレッドなど)。
要約
tfds.data_source
は、データソースを作成する API で、以下を目的としています。
純粋な Python パイプラインでの高速プロトタイピング
大規模なデータ集約型 ML パイプラインの管理
セットアップ
必要な依存関係をインストールしてインポートしましょう。
データソース
データソースは基本的に Python シーケンスです。そのため、以下のプロトコルを実装する必要があります。
警告: この API は現在も活発に開発されています。特に、現時点では、__getitem__
は入力で int
と list[int]
をサポートする必要があります。将来的には、標準に従って、おそらく int
のみがサポートされます。
基盤のファイル形式は有効なランダムアクセスをサポートする必要があります。現時点では、TFDS は array_record
に依存しています。
array_record
は、Riegeli から派生した新しいファイル形式です。IO 効率の新境地を達成しています。特に、ArrayRecord はレコードインデックスによる同時読み取り、書き込み、およびランダムアクセスをサポートしています。ArrayRecord は Riegeli を基盤としているため、同じ圧縮アルゴリズムをサポートしています。
fashion_mnist
はコンピュータビジョン用の共通データセットです。以下を使用するだけで、TFDS で ArrayRecord ベースのデータを取得することができます。
tfds.data_source
は便利なラッパーで、以下に相当します。
これは、データソースのディクショナリを出力します。
download_and_prepare
が実行し、レコードファイルを生成したら、TensorFlow は不要になります。すべては Python/NumPy で処理されます!
TensorFlow をアンインストールして、別のサブプロセスでデータソースを読み込みなおして、このことを確認してみましょう。
今後のバージョンでは、データセットの準備も TensorFlow を使用せずに行えるようにする予定です。
データソースには長さがあります。
以下のようにして、データセットの最初の要素にアクセスすると...
他の要素へのアクセスと同じように安価に行えます。これが、ランダムアクセスの定義です。
特徴量は NumPy DTypes(TensorFlow DTypes ではなく)を使用するようになりました。以下のようにして、特徴量を検査することができます。
特徴量の詳細は、ドキュメントで確認できます。ここでは、画像の形状とクラスの数を取得できます。
純粋な Python で使用する
Python でデータソースを反復することで、それを消費できます。
PyTorch で使用する
PyTorch は、ソース/サンプラー/ローダー構成のパラダイムを使用します。Torch では、「データソース」のことを「データセット」と呼んでいます。torch.utils.data
には、有効な入力パイプラインを Torch でビルドするために必要なすべての情報が含まれます。
通常のマップスタイルのデータセットとして、TFDS データソースを使用することができます。
まず、Torch をインストールしてインポートします。
トレーニング用のデータソースとテスト用のデータソースはすでに定義済みです(順に、ds['train']
と ds['test']
)。サンプラーとローダーを定義しましょう。
PyTorch で、最初のサンプルを使って、単純なロジスティック回帰をトレーニングし、評価します。
近日公開: JAX と使用する
Grain と緊密に作業を続けています。Grain はオープンソースの高速で決定論的な Python 用データローダーです。ご期待ください!
その他の資料
詳細については、tfds.data_source
API ドキュメントをご覧ください。