pytorchのdataloaderとtransformを自作する
pytorchのdataloaderとtransformを自作する
概要
最小構成でのサンプルにより、dataloaderとtransformの自作方法を例示する。
ポイントは前処理を行うtorchvision.transforms
に関する部分には、自由に関数を入れることができるというところである。ToTensor()
を使用すると形がndarrayからtensorに変化してしまうので注意する必要がある。ToTensor()
より前なら汎用的な関数を使用できるので便利だろう。
前処理の定義方法は、前処理の関数をlistとしてtorchvision.transforms.Compose()
に渡すことで行うことができる。Compose()
は必須ではないが、これを使うと複数の前処理を連鎖して実行できるようになるのだ。渡すlistの中身は関数オブジェクトか、__call__()
メソッドを持つオブジェクトであり、前から順番に実行される。以下の例ではf
が関数でtorchvision.transforms.ToTensor()
がオブジェクトである。pytorchが用意してくれている関数は基本的に__call__()
メソッドを持つオブジェクトである。オブジェクトにするとパラメーターを変更したいときに元の関数を変更しなくともよいので便利である。もちろん関数を返す関数を作成してもよいので、これは好みによるかなとおもう。
DataLoader
のnum_workers
の意味
DataLoaderで最終なdataの読み込みを行うイテレータを作成する、単なるイテレータならわざわざpytorchのDataLoaderを使わずとも簡単に自作できる。pytorchのDataloaderの場合は__len__()
を持ってる必要がある分、generatorが使用できないのでむしろ不便である。
しかし、DataLoader
にはloopを回している間に次のイテレータの実行をバックグラウンドで行なってくれるという大きな利点がある。前処理をおこなうworkerという名前で呼ばれるプロセスを立ち上げて、loopの中身を実行している間に事前に前処理を行なってくれるのだ。ちなみに、pythonは言語としてマルチスレッドに対応していないので、プロセスを立ち上げるという対応になっているのだと思う。
num_workers
はそのプロセス数を指定する物である。num_workers
の数だけworkerプロセスが立ち上げられる。気をつけなければならないのは、一つのイテレータないで並列化されるのではなく、それぞれのプロセスは別々のイテレータの前処理をおこなうという点だ。例えばnum_workers=5
とすると5個先のイテレータまでバックグラウンドで実行してくれる。
また、workerは常に投機的に実行し続けるので、同時に実行する前処理のプロセス数はnum_workers
で指定した数だが、早く終わるとさらに先の前処理までおこなうようになっている。
無限におこなうはずはないと思うが、どこまで投機的に実行するのかは追い切れていない。
所感としては、loopの中と前処理が毎回同じ実行コストで低コストあれば2で十分だが、ばらつきがある場合は、長くかかるloopや前処理のバッファを貯めるために少し多めに設定しても良いのではと思う。また、ほとんどの処理がシングルスレッドならGPUでの計算を邪魔しない程度に増やしても良いだろう。逆にデータのロードの際にIO waitがボトルネックになる場合は多くしないほうがよさそうではある。
sample code
import torch import torchvision import numpy as np import time def f(x): time.sleep(0.2) return 2*x class MyTransform: def __init__(self, a): self.a = a def __call__(self, x): return self.a * x class Mydatasets(torch.utils.data.Dataset): def __init__(self, leng, transform1 = None): self.transform1 = transform1 self.dataset = list(np.arange(leng).reshape(leng,1,1)) self.datanum = leng def __len__(self): return self.datanum def __getitem__(self, idx): out_data = self.dataset[idx] if self.transform1: out_data = self.transform1(out_data) return out_data def __str__(self): return f"""my original dataloader. data num: {self.datanum}""" if __name__ == "__main__": trans = torchvision.transforms.Compose([f, MyTransform(3), torchvision.transforms.ToTensor()]) dataset = Mydatasets(10, transform1=trans) print(dataset) trainloader = torch.utils.data.DataLoader(dataset, batch_size = 10, shuffle = False, num_workers = 2) for z in trainloader: print(z) # for check worker's behaviour for i, z in enumerate(trainloader): print(z[0][0][0][0]) if i%5 == 0 and i is not 0: time.sleep(10) else: time.sleep(0.1)