BF16に対応したOptimizerを用いてFSDPを用いた並列学習を高速化する

以前見かけたこちらのリポジトリ (tascj/offload_adam)を、huggingface/transformersに実装されているTrainerを用いたFSDPと併用してみる話です。リポジトリ作者様による解説はこちらのリンクを参照ください。

本記事のまとめ

  • tascj/offload_adamに実装されているAdam (AdamW) はstochastic roundingを用いてfp32からbf16への変換を行っています。このおかけで、fp32の学習時とほぼ同じlossの推移をbf16の学習で実現することができます。
  • 私が検証した時点ではFSDP1のみに対応していたため、FSDP2での動作確認はできませんでした。
  • transformersのTrainerと併用し、accelerateを用いてプログラムを実行する際は、configの指定が重要です。mixed precisionを有効化しているとモデルパラメータがfp32にupcastされるため、省メモリ化の恩恵を受けられません。SFTConfig等のbf16という引数や、accelerate configに記載するmixed_precision等の項目を適切に設定しましょう。
  • transformersのTrainerを用いる場合、optimizer_cls_and_kwargsという引数を設定することでカスタムoptimizerを使用することができます。

補足: 同様の機能を持つライブラリに関する調査

  • ざっと調べた限り、パラメーター更新等がtriton/cudaで記述されている かつ stochastic roundingを採用している 実装は tascj/offload_adam 以外に見つかりませんでした。
  • torchaoにもbfloat16対応のAdamWが実装されています。しかし、FSDP2には完全に対応しているわけではないようであり、今回の検証ではエラーによって実行を完了できませんでした。FSDP1を用いて今回検証した限り、torchaoの実装は tascj/offload_adam と同様のメモリ使用量で、30%程度遅いという結果になりました。
スポンサーリンク

この記事は役に立ちましたか?

もし参考になりましたら、下記のボタンで教えてください。

関連記事

コメント

この記事へのコメントはありません。

CAPTCHA