以前見かけたこちらのリポジトリ (tascj/offload_adam)を、huggingface/transformersに実装されているTrainerを用いたFSDPと併用してみる話です。リポジトリ作者様による解説はこちらのリンクを参照ください。
本記事のまとめ
- tascj/offload_adamに実装されているAdam (AdamW) はstochastic roundingを用いてfp32からbf16への変換を行っています。このおかけで、fp32の学習時とほぼ同じlossの推移をbf16の学習で実現することができます。
- 私が検証した時点ではFSDP1のみに対応していたため、FSDP2での動作確認はできませんでした。
- transformersのTrainerと併用し、accelerateを用いてプログラムを実行する際は、configの指定が重要です。mixed precisionを有効化しているとモデルパラメータがfp32にupcastされるため、省メモリ化の恩恵を受けられません。SFTConfig等のbf16という引数や、accelerate configに記載するmixed_precision等の項目を適切に設定しましょう。
- transformersのTrainerを用いる場合、optimizer_cls_and_kwargsという引数を設定することでカスタムoptimizerを使用することができます。
補足: 同様の機能を持つライブラリに関する調査
- ざっと調べた限り、パラメーター更新等がtriton/cudaで記述されている かつ stochastic roundingを採用している 実装は tascj/offload_adam 以外に見つかりませんでした。
- torchaoにもbfloat16対応のAdamWが実装されています。しかし、FSDP2には完全に対応しているわけではないようであり、今回の検証ではエラーによって実行を完了できませんでした。FSDP1を用いて今回検証した限り、torchaoの実装は tascj/offload_adam と同様のメモリ使用量で、30%程度遅いという結果になりました。
この記事は役に立ちましたか?
もし参考になりましたら、下記のボタンで教えてください。
コメント