pytorchで推論を行うときはmodel.eval()を行おう

pytorchで推論を行いたい。

pytorchでDeep Learningのモデルを組み立てるとき、Batch Normalizationを用いるのはもはや常識だ。しかし、Batch Normalizationを含むモデルで推論を行うとエラーが起きる。

以下のコードをみて欲しい。バッチサイズが1のときには実行できないのだ。バッチサイズが1だとnormalizationを行う余地がないのだが、内部では計算ができないようになっている。

そこで、推論を行う際は、model.eval()を行うことで実行できるようになる。 中で何を行なっているのかは下を参照すること。まだ中身は読んでいない。

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.bn = torch.nn.BatchNorm1d(3)
    
    def forward(self, x):
        return self.bn(x)

def main():
    x1 = torch.ones(1,3)
    x2 = torch.ones(2,3)
    model = Foo()
    y = model(x1) # this line will fail to run 
    y = model(x2) # this line can run without error
    model.eval() # switch to inference mode
    y = model(x1) # this line can run without error
    model.train() # switch to traing mode

if __name__ == '__main__':
    main()

What does model.eval() do for batchnorm layer? - PyTorch Forums