PyTorchで音声や画像のカスタムデータセットを自作する方法

スポンサーリンク

今回は、PyTorchのカスタムデータセットを作成する方法を学びます。

機械学習では、モデルは学習させたデータと同じだけ良いモデルになります。

MNISTやCIFAR、ImageNetなど、初心者の教育やベンチマークに使われるような、あらかじめ作られた標準的なデータセットがたくさん存在します。

しかし、このようなあらかじめ定義されたデータセットはあまり多くなく、比較的新しい問題に取り組んでいる場合、あらかじめ定義されたデータセットが得られず、独自のデータセットを使用して学習する必要があるかもしれません。

この記事では、PyTorchを使ってカスタムデータから初級レベルのデータセットセレーションを理解することにします。

スポンサーリンク

PyTorchのデータセットとDataLoaderクラスを理解する

データサンプルを処理するコードは煩雑になりやすく、メンテナンスも大変です。

読みやすさとモジュール化のために、データセットのコードはモデルの学習コードから切り離すのが理想的です。

PyTorchは2つのデータプリミティブを提供しています。

 PyTorch には torch.utils.data.DataLoadertorch.utils.data.Dataset という2つのデータプリミティブがあり、ロード済みのデータセットと独自のデータを利用することができます。

 Datasetにはサンプルとそれに対応するラベルが格納され、DataLoaderはサンプルに簡単にアクセスできるようにDataset` の周りをイテラブルにラップしている。

つまり、Dataset はデータをディスクからコンピュータで読み込める形式にロードする役割を担うクラスです。

DataLoaderやユーザがディスクからメモリにデータをロードする必要がある場合にのみ、メモリをロードします。

これは、すべての画像を一度にメモリに保存するのではなく、必要なときに読み込むため、メモリ効率がよいのです。

torch Datasetクラスは、データセットを表す抽象クラスです。

カスタムデータセットを作るには、この抽象クラスを継承すればよい。

ただし、非常に重要な2つの関数を必ず定義してください。

  • __len__ は、 len(dataset) がデータセットの大きさを返すようにします。
  • __getitem__ は、 dataset[i] を使って i 番目のサンプルを取得できるようなインデックスをサポートします。

DataLoaderは、これらのメソッドを呼び出してメモリをロードするだけです。

この記事では、カスタムデータセットの作成にのみ焦点を当てます。

DataLoaderは非常に多くの拡張が可能ですが、この記事の範囲外です。

さて、DataLoaderDatasetの基本的な機能を学んだので、実際の生活の中でどのように行われているかの例を見ていきます。

ラベルなし画像からカスタムデータセットを読み込む

これは比較的簡単な例で、あるフォルダにあるすべての画像をGAN学習用のデータセットとしてロードするものです。

すべてのデータは同じクラスのものなので、今のところラベリングは気にする必要はありません。

1. Custom Dataset クラスの初期化

# Imports
import os
from PIL import Image
from torch.utils.data import Dataset
from natsort import natsorted
from torchvision import datasets, transforms
 
# Define your own class LoadFromFolder
class LoadFromFolder(Dataset):
    def __init__(self, main_dir, transform):
         
        # Set the loading directory
        self.main_dir = main_dir
        self.transform = transform
         
        # List all images in folder and count them
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsorted(all_imgs)

カスタムデータセットに特化した関数を2つ定義する必要があります。

2. 2.Defining _len_to_function

この関数は、カスタムデータセットから正常に読み込まれたアイテムの数を識別することができます。

def __len__(self):
    # Return the previously computed number of images
    return len(self.total_imgs)

3. 3. ⾳⼭の定義

def __getitem__(self, idx):
    img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
     
    # Use PIL for image loading
    image = Image.open(img_loc).convert("RGB")
    # Apply the transformations
    tensor_image = self.transform(image)
    return tensor_image

データセットの定義が完了したら、次はインスタンスの作成です。

を使用してインスタンスを作成します。

ラベル付き画像からカスタムデータセットを読み込む

例えば、猫と犬の分類器のようなもう少し複雑な問題があるとしましょう。

ここで、データセットの画像にラベルを付ける必要があります。

このために、特別なPyTorchデータセットクラスImageFolderがあります。

以下のようなディレクトリ構造になっているとします。

dataset = LoadFromFolder(main_dir="./data", transform=transform)
dataloader = DataLoader(dataset)
print(next(iter(dataloader)).shape)  # prints shape of image with single batch

猫の画像はすべてcatというフォルダに、犬の画像はすべてdogsというフォルダに格納されています。

このようなディレクトリ構造でデータセットを作成する場合は、次のようにします。

from torchvision.datasets import ImageFolder
dataset = ImageFolder(root="./root", transform=transform)
dataloader = DataLoader(dataset)
print(next(iter(dataloader)).shape)  # prints shape of image with single batch

ImageFolder クラスを継承することで、画像のラベル付けや読み込みの方法をいつでも変更することができます。

カスタムオーディオデータセットの読み込み

オーディオを扱う場合も、同じテクニックが適用できます。

変わるのは、データセットの長さの測り方と、ファイルをメモリに読み込む方法だけです。

from torch.utils.data import Dataset
 
class SpectrogramDataset(Dataset):
 
    def __init__(self,file_label_ds,  transform, audio_path=""):
        self.ds= file_label_ds
        self.transform = transform
        self.audio_path=audio_path
     
    # The length of the dataset
    def __len__(self):
        return len(self.ds)
 
    # Load of item in folder
    def __getitem__(self, index):
        file,label=self.ds[index]
        x=self.transform(self.audio_path+file)
        return x, file, label
# file_label_ds is a dataset that gives you the file name and label.
dataset = SpectrogramDataset(file_label_ds, transform)

まとめ

以上で本記事は終了となります。

今後もDeep LearningとPyTorchに関する記事をお届けします。

タイトルとURLをコピーしました