PyTorchモデルがTorchScriptに変換できないときに確認すること

torch.jit.traceを試してみる

torch.jit.script(model)

がうまくいかない場合でも、

torch.jit.trace(model, sample_input)

ならうまくいく場合があります。torch.jit.traceを使った場合、変換後のモデルに入力するtensorはtrace作成時と同じ形のものを使う必要があります。

torch.autograd.Functionを使っていないか?

TorchScriptは、torch.autograd.Functionを使ったカスタム演算に対応していないようです。メモリ削減のためにSwish関数を実装する場合にtorch.autograd.Functionを使うことがあるようです。
自分の場合は、torch.jit.traceを使って変換はできたものの保存することができませんでした。

ひょっとしたらPythonではtorch.autograd.FunctionがあってもTorchScriptとして動くのかもしれませんが、私は検証していません。

その他

少し情報が古いかもしれませんが、他にも注意すべきことがあるようです。

この記事は役に立ちましたか?

もし参考になりましたら、下記のボタンで教えてください。

関連記事

コメント

この記事へのコメントはありません。

CAPTCHA