pytorchで推論を行うときはmodel.eval()を行おう
pytorchで推論を行いたい。
pytorchでDeep Learningのモデルを組み立てるとき、Batch Normalizationを用いるのはもはや常識だ。しかし、Batch Normalizationを含むモデルで推論を行うとエラーが起きる。
以下のコードをみて欲しい。バッチサイズが1のときには実行できないのだ。バッチサイズが1だとnormalizationを行う余地がないのだが、内部では計算ができないようになっている。
そこで、推論を行う際は、model.eval()
を行うことで実行できるようになる。
中で何を行なっているのかは下を参照すること。まだ中身は読んでいない。
import torch class Foo(torch.nn.Module): def __init__(self): super(Foo, self).__init__() self.bn = torch.nn.BatchNorm1d(3) def forward(self, x): return self.bn(x) def main(): x1 = torch.ones(1,3) x2 = torch.ones(2,3) model = Foo() y = model(x1) # this line will fail to run y = model(x2) # this line can run without error model.eval() # switch to inference mode y = model(x1) # this line can run without error model.train() # switch to traing mode if __name__ == '__main__': main()
What does model.eval() do for batchnorm layer? - PyTorch Forums