pytorchのdataloaderとtransformを自作する

pytorchのdataloaderとtransformを自作する

概要

最小構成でのサンプルにより、dataloaderとtransformの自作方法を例示する。

ポイントは前処理を行うtorchvision.transformsに関する部分には、自由に関数を入れることができるというところである。ToTensor()を使用すると形がndarrayからtensorに変化してしまうので注意する必要がある。ToTensor()より前なら汎用的な関数を使用できるので便利だろう。

前処理の定義方法は、前処理の関数をlistとしてtorchvision.transforms.Compose()に渡すことで行うことができる。Compose()は必須ではないが、これを使うと複数の前処理を連鎖して実行できるようになるのだ。渡すlistの中身は関数オブジェクトか、__call__()メソッドを持つオブジェクトであり、前から順番に実行される。以下の例ではfが関数でtorchvision.transforms.ToTensor()がオブジェクトである。pytorchが用意してくれている関数は基本的に__call__()メソッドを持つオブジェクトである。オブジェクトにするとパラメーターを変更したいときに元の関数を変更しなくともよいので便利である。もちろん関数を返す関数を作成してもよいので、これは好みによるかなとおもう。

DataLoadernum_workersの意味

DataLoaderで最終なdataの読み込みを行うイテレータを作成する、単なるイテレータならわざわざpytorchのDataLoaderを使わずとも簡単に自作できる。pytorchのDataloaderの場合は__len__()を持ってる必要がある分、generatorが使用できないのでむしろ不便である。 しかし、DataLoaderにはloopを回している間に次のイテレータの実行をバックグラウンドで行なってくれるという大きな利点がある。前処理をおこなうworkerという名前で呼ばれるプロセスを立ち上げて、loopの中身を実行している間に事前に前処理を行なってくれるのだ。ちなみに、pythonは言語としてマルチスレッドに対応していないので、プロセスを立ち上げるという対応になっているのだと思う。

num_workersはそのプロセス数を指定する物である。num_workersの数だけworkerプロセスが立ち上げられる。気をつけなければならないのは、一つのイテレータないで並列化されるのではなく、それぞれのプロセスは別々のイテレータの前処理をおこなうという点だ。例えばnum_workers=5とすると5個先のイテレータまでバックグラウンドで実行してくれる。 また、workerは常に投機的に実行し続けるので、同時に実行する前処理のプロセス数はnum_workersで指定した数だが、早く終わるとさらに先の前処理までおこなうようになっている。 無限におこなうはずはないと思うが、どこまで投機的に実行するのかは追い切れていない。

所感としては、loopの中と前処理が毎回同じ実行コストで低コストあれば2で十分だが、ばらつきがある場合は、長くかかるloopや前処理のバッファを貯めるために少し多めに設定しても良いのではと思う。また、ほとんどの処理がシングルスレッドならGPUでの計算を邪魔しない程度に増やしても良いだろう。逆にデータのロードの際にIO waitがボトルネックになる場合は多くしないほうがよさそうではある。

sample code

import torch
import torchvision
import numpy as np
import time


def f(x):
    time.sleep(0.2)
    return 2*x


class MyTransform:
    def __init__(self, a):
        self.a = a

    def __call__(self, x):
        return self.a * x


class Mydatasets(torch.utils.data.Dataset):
    def __init__(self, leng, transform1 = None):
        self.transform1 = transform1

        self.dataset = list(np.arange(leng).reshape(leng,1,1))
        self.datanum = leng

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_data = self.dataset[idx]

        if self.transform1:
            out_data = self.transform1(out_data)
        return out_data
    
    def __str__(self):
        return f"""my original dataloader. 
        data num: {self.datanum}"""


if __name__ == "__main__":
    trans = torchvision.transforms.Compose([f, MyTransform(3), torchvision.transforms.ToTensor()])
    dataset = Mydatasets(10, transform1=trans)
    print(dataset)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size = 10, shuffle = False, num_workers = 2)
    for z in trainloader:
        print(z)

    # for check worker's behaviour
    for i, z in enumerate(trainloader):
        print(z[0][0][0][0])
        if i%5 == 0 and i is not 0:
           time.sleep(10)
        else:
            time.sleep(0.1)