torch.jit.traceを試してみる
torch.jit.script(model)
がうまくいかない場合でも、
torch.jit.trace(model, sample_input)
ならうまくいく場合があります。torch.jit.traceを使った場合、変換後のモデルに入力するtensorはtrace作成時と同じ形のものを使う必要があります。
torch.autograd.Functionを使っていないか?
TorchScriptは、torch.autograd.Functionを使ったカスタム演算に対応していないようです。メモリ削減のためにSwish関数を実装する場合にtorch.autograd.Functionを使うことがあるようです。
自分の場合は、torch.jit.traceを使って変換はできたものの保存することができませんでした。
ひょっとしたらPythonではtorch.autograd.FunctionがあってもTorchScriptとして動くのかもしれませんが、私は検証していません。
その他
少し情報が古いかもしれませんが、他にも注意すべきことがあるようです。
この記事は役に立ちましたか?
もし参考になりましたら、下記のボタンで教えてください。
コメント