DataLoader の transform に target を渡す
はじめに
torchvision の transforms compose に target も渡して DataLoader の中で同時に加工しようと試みて、うまくいきませんでした。
対策を記します。
やろうとしたこと
解析は Semantic Segmentation。 データ拡張のための加工を PyTorch の DataLoader で行おうと考えました。
まず、PyTorch のチュートリアルを参照して、compose に必要な処理をまとめて…
import torchvision.transforms as transforms
process_list = [自作処理いろいろ] # 画像データと bounding box を加工する処理を自作した
transform_train = transforms.Compose(process_list) # 処理をまとめる
# 上の transform_train を dataset で呼び出したらエラーが発生した。
分類問題と同じように解析を走らせて… とやったらエラーが発生。
TypeError: __call__() takes 2 positional arguments but 3 were given
引数が合っていないという内容でした。
なぜうまくいかないのか
エラーの原因がわからないので、torchvision の transforms Compose のソースコードを確認しました。
__call__ メソッドを見ると、引数は img のみ。
つまり、Compose は画像データのデータ拡張のみのための関数であり、target (今回は bboxes)などの追加引数の引き渡しは一切考慮されていないことが分かりました。
解決法
自分は、PyTorch のチュートリアルを参考にしながらコードを書いていたのになぜこんなことに?
image と target を同時に加工したいならば、自分で Compose クラスを実装して、使用する方針がよさそうです。
コード内に以下のコードを差しはさみます。
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, **kwargs):
for t in self.transforms:
img, kwargs = t(img, **kwargs)
return img, kwargs
感想
そもそも、なぜ詰まったかというと、チュートリアルで使用している compose が普段使用している torchvision transpose compose ではないということに気が付かなかったためです。
当然、torchvision だと考えていたのが間違いでした。
ちょっと騙された気分です。
こんな簡単なことで、ずいぶんと時間を浪費しました。
2020 07 05