今回は、PyTorch Lightningを使って、最初のモデルを学習してみます。
PyTorchは2016年の登場以来、多くの研究者に選ばれています。
よりpythonicなアプローチと、CUDAの非常に強力なサポートにより、人気を博しました。
しかし、ボイラープレート・コードによる根本的な問題があります。
複数のGPUを使った分散学習など、一部の機能はパワーユーザー向けのものです。
PyTorch lightningはPyTorchのラッパーで、PyTorchの柔軟性を奪うことなくKerasのようなインターフェイスを与えることを目的としています。
すでにPyTorchを日常的に使用している場合、PyTorch-lightningはあなたのツールセットへの良い追加になります。
PyTorch Lightningをはじめよう
ここでは、最初のモデルを作成するためのステップをわかりやすく説明します。
それではさっそく始めてみましょう。
1. PyTorch Lightningのインストール
PyTorch-lightningをインストールするには、簡単なpipコマンドを実行します。
また、あらかじめ定義されたデータセットから始めたい場合は、lightning boltsモジュールが便利です。
pip install pytorch-lightning lightning-bolts
|
この記事もチェック:PyTorchで音声や画像のカスタムデータセットを自作する方法
2. モジュールのインポート
まず、pytorch と pytorch-lightning モジュールをインポートします。
import torch
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl
|
よくある質問です。
“すでにlightningを使っているのに、なぜtorchが必要なのか?” という疑問があるかもしれません。
Lightningはtorchでのコーディングをより速くします。
torchの上に構築されたlightningはtorchモジュールで簡単に拡張でき、ユーザーは必要なときにアプリケーション固有の重要な変更を行うことができます。
3. MNISTデータセットのセットアップ
PyTorchと異なり、Lightningはデータベースコードをよりアクセスしやすく、整理しています。
DataModuleは、train_dataloader, val_dataloader(s), test_dataloader(s) と、必要な変換やデータ処理・ダウンロードのステップを集めただけのものです。
PyTorchでは、MNIST DataModuleは一般に次のように定義される。
from torchvision import datasets, transforms
# transforms # prepare transforms standard to MNIST transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(( 0.1307 ,), ( 0.3081 ,))])
mnist_train = MNIST(os.getcwd(), train = True , download = True , transform = transform)
mnist_train = DataLoader(mnist_train, batch_size = 64 )
|
見ての通り、DataModuleは1つのブロックに構成されているわけではありません。
もし、データ準備ステップや検証データローダーのような機能を追加したい場合、コードはより複雑になります。
Lightningはコードを LightningDataModule
クラスに整理しています。
PyTorch-LightningでDataModuleを定義する
1. データセットのセットアップ
まず、LightningDataModule
を使用してデータセットをロードし、セットアップします。
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(pl.LightningDataModule):
def __init__( self , data_dir: str = './' ):
super ().__init__()
self .data_dir = data_dir
self .transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(( 0.1307 ,), ( 0.3081 ,))
])
# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self .dims = ( 1 , 28 , 28 )
def prepare_data( self ):
# download
MNIST( self .data_dir, train = True , download = True )
MNIST( self .data_dir, train = False , download = True )
def setup( self , stage = None ):
# Assign train/val datasets for use in dataloaders
if stage = = 'fit' or stage is None :
mnist_full = MNIST( self .data_dir, train = True , transform = self .transform)
self .mnist_train, self .mnist_val = random_split(mnist_full, [ 55000 , 5000 ])
# Assign test dataset for use in dataloader(s)
if stage = = 'test' or stage is None :
self .mnist_test = MNIST( self .data_dir, train = False , transform = self .transform)
|
preapre_data関数は、データをダウンロードし、torch が読める形式で保存します。
setup 関数は、データセットを train、test、validation に分割します。
これらの関数は、データがどの程度の前処理を必要とするかによって、任意に複雑なものにすることができる。
2. DataLoaderの定義
さて、設定ができたので、データローダーの関数を追加していきます。
def train_dataloader( self ):
return DataLoader( self .mnist_train, batch_size = 32 )
def val_dataloader( self ):
return DataLoader( self .mnist_val, batch_size = 32 )
def test_dataloader( self ):
return DataLoader( self .mnist_test, batch_size = 32 )
|
3. MNIST DataModule の最終的な外観
最終的な LightningDataModule
は以下のようなものです。
class MNISTDataModule(pl.LightningDataModule):
def __init__( self , data_dir: str = './' ):
super ().__init__()
self .data_dir = data_dir
self .transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(( 0.1307 ,), ( 0.3081 ,))
])
# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self .dims = ( 1 , 28 , 28 )
def prepare_data( self ):
# download
MNIST( self .data_dir, train = True , download = True )
MNIST( self .data_dir, train = False , download = True )
def setup( self , stage = None ):
# Assign train/val datasets for use in dataloaders
if stage = = 'fit' or stage is None :
mnist_full = MNIST( self .data_dir, train = True , transform = self .transform)
self .mnist_train, self .mnist_val = random_split(mnist_full, [ 55000 , 5000 ])
# Assign test dataset for use in dataloader(s)
if stage = = 'test' or stage is None :
self .mnist_test = MNIST( self .data_dir, train = False , transform = self .transform)
def train_dataloader( self ):
return DataLoader( self .mnist_train, batch_size = 32 )
def val_dataloader( self ):
return DataLoader( self .mnist_val, batch_size = 32 )
def test_dataloader( self ):
return DataLoader( self .mnist_test, batch_size = 32 )
|
MNIST データモジュールは PyTorch-bolts のデータモジュールであらかじめ定義されています。
もし、自分でコードを全部書くのが面倒なら、データモジュールをインポートして、それを使って作業を始めることができます。
from pl_bolts.datamodules import MNISTDataModule
# Create MNIST DataModule instance data_module = MNISTDataModule()
|
さて、データが手に入りましたので、学習用のモデルが必要です。
マルチパーセプトロンモデルの作成
ライティングモデルはPyTorchのモデルクラスと非常に似ていますが、学習を容易にするためにいくつかの特別なクラス関数を持っていることが特徴です。
init__と
forward`メソッドはPyTorchと全く同じです。
ここでは、3層の知覚を作成しており、各層の知覚の数は(128, 256, 10)となっています。
また、サイズ28 * 28 (784)の入力層があり、28×28のMNIST画像を平坦化したものを取り込んでいます。
1. PyTorchのような基本モデル
class MyMNISTModel(nn.Module):
def __init__( self ):
super ().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self .layer_1 = nn.Linear( 28 * 28 , 128 )
# The hidden layer of size 256
self .layer_2 = nn.Linear( 128 , 256 )
# 3rd hidden layer of size 10.
# This the prediction layer
self .layer_3 = nn.Linear( 256 , 10 )
def forward( self , x):
batch_size, channels, width, height = x.size()
# Flatten the image into a linear tensor
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, - 1 )
# Pass the tensor through the layers
x = self .layer_1(x)
x = F.relu(x)
x = self .layer_2(x)
x = F.relu(x)
x = self .layer_3(x)
# Softmax the values to get a probability
x = F.log_softmax(x, dim = 1 )
return x
|
このモデルが動くかどうか、ランダムな値(28, 28)を使って確認してみましょう。
net = MyMNISTModel()
x = torch.randn( 1 , 1 , 28 , 28 )
print (net(x).shape)
|
結果は、以下の通りになります。
torch.Size([1, 10]) |
1はバッチを、10は出力クラスの数を示します。
つまり、我々のモデルは正常に動作している。
2. 初期化関数と転送関数の定義
PyTorch DataModuleはpl.LightningModule
からプロパティを派生させる以外は全く同じように見えるでしょう。
Lightningネットワークは以下のようになります。
class MyMNISTModel(pl.LightningModule):
def __init__( self ):
super ().__init__()
...
def forward( self , x):
....
|
これらの基本的なトーチ関数に加え、lightingはトレーニング、テスト、検証のループの中で起こることを定義できる関数を提供します。
2. トレーニングループとバリデーションループの定義
モデルの学習と検証のための学習ループを定義します。
def training_step( self , batch, batch_idx):
x, y = batch
# Pass through the forward function of the network
logits = self (x)
loss = F.nll_loss(logits, y)
return loss
def validation_step( self , batch, batch_idx):
x, y = batch
logits = self (x)
loss = F.nll_loss(logits, y)
return loss
def test_step( self , batch, batch_idx):
x, y = batch
logits = self (x)
loss = F.nll_loss(logits, y)
y_hat = torch.argmax(logits, dim = 1 )
accuracy = torch. sum (y = = y_hat).item() / ( len (y) * 1.0 )
output = dict ({
'test_loss' : loss,
'test_acc' : torch.tensor(accuracy),
})
return output
|
3. オプティマイザー
Lightning モデルでは、モデル定義の内部で特定のモデルに対するオプティマイザを定義することができます。
# We are using the ADAM optimizer for this tutorial def configure_optimizers( self ):
return torch.optim.Adam( self .parameters(), lr = 1e - 3 )
|
4. モデルの最終的な外観
最終的なLightningモデルの外観は、以下のようになります。
class MyMNISTModel(pl.LightningModule):
def __init__( self ):
super ().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self .layer_1 = nn.Linear( 28 * 28 , 128 )
# The hidden layer of size 256
self .layer_2 = nn.Linear( 128 , 256 )
# 3rd hidden layer of size 10.
# This the prediction layer
self .layer_3 = nn.Linear( 256 , 10 )
def forward( self , x):
batch_size, channels, width, height = x.size()
# Flatten the image into a linear tensor
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, - 1 )
# Pass the tensor through the layers
x = self .layer_1(x)
x = F.relu(x)
x = self .layer_2(x)
x = F.relu(x)
x = self .layer_3(x)
# Softmax the values to get a probability
x = F.log_softmax(x, dim = 1 )
return x
def training_step( self , batch, batch_idx):
x, y = batch
# Pass through the forward function of the network
logits = self (x)
loss = F.nll_loss(logits, y)
return loss
def validation_step( self , batch, batch_idx):
x, y = batch
logits = self (x)
loss = F.nll_loss(logits, y)
return loss
def test_step( self , batch, batch_idx):
x, y = batch
logits = self (x)
loss = F.nll_loss(logits, y)
y_hat = torch.argmax(logits, dim = 1 )
accuracy = torch. sum (y = = y_hat).item() / ( len (y) * 1.0 )
output = dict ({
'test_loss' : loss,
'test_acc' : torch.tensor(accuracy),
})
return output
def configure_optimizers( self ):
return torch.optim.Adam( self .parameters(), lr = 1e - 3 )
|
これでデータとモデルの準備は完了です。
それでは、データを使ってモデルの学習を始めましょう。
5. モデルの学習
損失を求め、バックワードパスを行う従来の定型的なループの代わりに、pytorch-lightningモジュールのTrainerが多くのコードなしに私たちのために仕事をしてくれます。
まず、lightningのTrainerを特定のパラメータで初期化します。
from pytorch_lightning import Trainer
# Set gpus = 0 for training on cpu # Set the max_epochs for maximum number of epochs you want trainer = Trainer(gpus = 1 , max_epochs = 20 )
|
MNISTDataModuleを使用してデータセットをフィットさせます。
trainer.fit(net, data_module) |
trainer.test(test_dataloaders = data_module.train_dataloader())
|
6. 結果
最終的な精度をトレーニングデータセットで確認しましょう。
-------------------------------------------------------------------------------- DATALOADER:0 TEST RESULTS {'test_acc': tensor(.98), 'test_loss': tensor(0.0017, device='cuda:0')} -------------------------------------------------------------------------------- |
結果は、以下の通りになります。
trainer.test(test_dataloaders = data_module.test_dataloader())
|
トレーニングデータセットで高い精度を得ることは、オーバーフィッティングの可能性があります。
そこで、先ほど分離したテストデータセットでモデルをテストすることも必要です。
検証用データセットでモデルの最終的な精度を確認してみましょう。
-------------------------------------------------------------------------------- DATALOADER:0 TEST RESULTS {'test_acc': tensor(.96), 'test_loss': tensor(0.0021, device='cuda:0')} -------------------------------------------------------------------------------- |
結果は以下の通りです。
これらの結果から、モデルがデータに対してうまく学習したことが確認できます。
まとめ
以上で、PyTorch-lightningのチュートリアルを終了します。
PyTorch-lightningは比較的新しく、急速に発展しているため、近い将来さらに多くの機能が追加されることが予想されます。
機械学習やディープラーニングに関するこのような記事を今後も期待してください。
この記事もチェック:Pythonで機械学習をライブラリや概要をまとめてみた