今回は、PyTorchのカスタムデータセットを作成する方法を学びます。
機械学習では、モデルは学習させたデータと同じだけ良いモデルになります。
MNISTやCIFAR、ImageNetなど、初心者の教育やベンチマークに使われるような、あらかじめ作られた標準的なデータセットがたくさん存在します。
しかし、このようなあらかじめ定義されたデータセットはあまり多くなく、比較的新しい問題に取り組んでいる場合、あらかじめ定義されたデータセットが得られず、独自のデータセットを使用して学習する必要があるかもしれません。
この記事では、PyTorchを使ってカスタムデータから初級レベルのデータセットセレーションを理解することにします。
この記事もチェック:Pythonでデータセットから別のデータセットへピボットテーブルを作成する
PyTorchのデータセットとDataLoaderクラスを理解する
データサンプルを処理するコードは煩雑になりやすく、メンテナンスも大変です。
読みやすさとモジュール化のために、データセットのコードはモデルの学習コードから切り離すのが理想的です。
PyTorchは2つのデータプリミティブを提供しています。
PyTorch には torch.utils.data.DataLoader と torch.utils.data.Dataset という2つのデータプリミティブがあり、ロード済みのデータセットと独自のデータを利用することができます。
Datasetにはサンプルとそれに対応するラベルが格納され、DataLoaderはサンプルに簡単にアクセスできるようにDataset` の周りをイテラブルにラップしている。
つまり、Dataset はデータをディスクからコンピュータで読み込める形式にロードする役割を担うクラスです。
DataLoaderやユーザがディスクからメモリにデータをロードする必要がある場合にのみ、メモリをロードします。
これは、すべての画像を一度にメモリに保存するのではなく、必要なときに読み込むため、メモリ効率がよいのです。
torch Datasetクラスは、データセットを表す抽象クラスです。
カスタムデータセットを作るには、この抽象クラスを継承すればよい。
ただし、非常に重要な2つの関数を必ず定義してください。
-
__len__は、len(dataset)がデータセットの大きさを返すようにします。 -
__getitem__は、dataset[i]を使って i 番目のサンプルを取得できるようなインデックスをサポートします。
DataLoaderは、これらのメソッドを呼び出してメモリをロードするだけです。
この記事では、カスタムデータセットの作成にのみ焦点を当てます。
DataLoaderは非常に多くの拡張が可能ですが、この記事の範囲外です。
さて、DataLoaderとDatasetの基本的な機能を学んだので、実際の生活の中でどのように行われているかの例を見ていきます。
ラベルなし画像からカスタムデータセットを読み込む
これは比較的簡単な例で、あるフォルダにあるすべての画像をGAN学習用のデータセットとしてロードするものです。
すべてのデータは同じクラスのものなので、今のところラベリングは気にする必要はありません。
1. Custom Dataset クラスの初期化
# Importsimport os
from PIL import Image
from torch.utils.data import Dataset
from natsort import natsorted
from torchvision import datasets, transforms
# Define your own class LoadFromFolderclass LoadFromFolder(Dataset):
def __init__(self, main_dir, transform):
# Set the loading directory
self.main_dir = main_dir
self.transform = transform
# List all images in folder and count them
all_imgs = os.listdir(main_dir)
self.total_imgs = natsorted(all_imgs)
|
カスタムデータセットに特化した関数を2つ定義する必要があります。
2. 2.Defining _len_to_function
この関数は、カスタムデータセットから正常に読み込まれたアイテムの数を識別することができます。
def __len__(self):
# Return the previously computed number of images
return len(self.total_imgs)
|
3. 3. ⾳⼭の定義
def __getitem__(self, idx):
img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
# Use PIL for image loading
image = Image.open(img_loc).convert("RGB")
# Apply the transformations
tensor_image = self.transform(image)
return tensor_image
|
データセットの定義が完了したら、次はインスタンスの作成です。
を使用してインスタンスを作成します。
ラベル付き画像からカスタムデータセットを読み込む
例えば、猫と犬の分類器のようなもう少し複雑な問題があるとしましょう。
ここで、データセットの画像にラベルを付ける必要があります。
このために、特別なPyTorchデータセットクラスImageFolderがあります。
以下のようなディレクトリ構造になっているとします。
dataset = LoadFromFolder(main_dir="./data", transform=transform)
dataloader = DataLoader(dataset)
print(next(iter(dataloader)).shape) # prints shape of image with single batch
|
猫の画像はすべてcatというフォルダに、犬の画像はすべてdogsというフォルダに格納されています。
このようなディレクトリ構造でデータセットを作成する場合は、次のようにします。
from torchvision.datasets import ImageFolder
dataset = ImageFolder(root="./root", transform=transform)
dataloader = DataLoader(dataset)
print(next(iter(dataloader)).shape) # prints shape of image with single batch
|
ImageFolder クラスを継承することで、画像のラベル付けや読み込みの方法をいつでも変更することができます。
この記事もチェック:Pythonのcatboostの使い方|2値分類問題を解いてみた
カスタムオーディオデータセットの読み込み
オーディオを扱う場合も、同じテクニックが適用できます。
変わるのは、データセットの長さの測り方と、ファイルをメモリに読み込む方法だけです。
from torch.utils.data import Dataset
class SpectrogramDataset(Dataset):
def __init__(self,file_label_ds, transform, audio_path=""):
self.ds= file_label_ds
self.transform = transform
self.audio_path=audio_path
# The length of the dataset
def __len__(self):
return len(self.ds)
# Load of item in folder
def __getitem__(self, index):
file,label=self.ds[index]
x=self.transform(self.audio_path+file)
return x, file, label
|
# file_label_ds is a dataset that gives you the file name and label.dataset = SpectrogramDataset(file_label_ds, transform)
|
まとめ
以上で本記事は終了となります。
今後もDeep LearningとPyTorchに関する記事をお届けします。