DataLoader の transform に target (bboxesなど) を渡す方法

はじめに

torchvision の transforms compose に target も渡してDataLoader の中で同時に加工しようと試みて、うまくいきませんでした。

結論を記します。

やろうとしたこと

解析は Semantic Segmentation。

データ拡張のための加工を PyTorch の DataLoader で行おうと考えました。

まず、PyTorch のチュートリアルを参照して、compose に必要な処理をまとめて…

import torchvision.transforms as transforms

process_list = [自作処理いろいろ]  # 画像データと bounding box を加工する処理を自作した

transform_train = transforms.Compose(process_list)  # 処理をまとめる
# 上の transform_train を dataset で呼び出したらエラーが発生した。

分類問題と同じように解析を走らせて… とやったらエラーが発生。

TypeError: __call__() takes 2 positional arguments but 3 were given

引数が合っていないという内容でした。

なぜうまくいかないのか

エラーの原因がわからないので、torchvision の transforms Compose のソースコードを確認しました。

__call__ メソッドを見ると、引数は img のみ。
つまり、Compose は画像データのデータ拡張のみのための関数であり、target (今回はbboxes)などの追加引数の引き渡しは一切考慮されていないことが分かりました。

解決法

自分は、PyTorch のチュートリアルを参考にしながらコードを書いていたのになぜこんなことに?

image と target を同時に加工したいならば、自分で Compose クラスを実装して、使用する方針がよさそう。

コード内に以下のコードを差しはさみます。

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, **kwargs):
        for t in self.transforms:
            img, kwargs = t(img, **kwargs)
        return img, kwargs

感想

そもそも、なぜ詰まったかというと、チュートリアルで使用しているcompose が普段使用している torchvision transpose compose ではないということに気が付かなかったためです。

当然、torchvision だと考えていたのが間違いでした。
ちょっと騙された気分です。

こんな簡単なことで、ずいぶんと時間を浪費しました。

コメント

タイトルとURLをコピーしました