PyTorchのGANを使ってMNISTトレーニングを行う方法

スポンサーリンク

GANは、2014年にグッドフェローが立ち上げて以来、話題になっています。この記事では、PyTorchで最初のGANをトレーニングすることを学びます。また、GANの内部動作の説明と、PyTorchによるGANの簡単な実装をウォークスルーするようにします。

スポンサーリンク

インポートするライブラリ

まず、実装で使用するライブラリや関数をインポートします。

import torch
from torch import nn
 
from torchvision import transforms
from torchvision.utils import make_grid
 
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
 
import matplotlib.pyplot as plt
from IPython.display import clear_output

GANとは?

生成ネットワークとは、簡単に言うと、学習データから学習し、学習データのようなデータを生成することができるネットワークです。生成モデルの設計には様々な方法があり、その1つが敵対的モデルです。

生成的敵対ネットワークでは、生成器と識別器の2つのサブモデルが存在します。これから、これらのサブモデルについて詳しく見ていく。

1. ジェネレーター

ジェネレータは、その名の通り、画像を生成する役割を担っている。

ジェネレーターは、小さな低次元入力(一般に1次元ベクトル)を受け取り、128x128x3次元の画像データを出力として与える。

この低次元を高次元に拡大する操作は、直列デコンボリューションとコンボリューションのレイヤーを使って達成される。

このジェネレーターは、低次元のデータを取り込み、高次元の画像データにマッピングする関数と考えることができる。

学習期間中、ジェネレータは低次元データから高次元データへのマッピングをより効果的に行う方法を学習します。

生成器の目標は、実画像の識別器を騙せるような画像を生成することです。

class Generator(nn.Module):
  def __init__(self, z_dim, im_chan, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
             
            # We define the generator as stacks of deconvolution layers
            # with batch normalization and non-linear activation function
            # You can try to play with the values of the layers
 
            nn.ConvTranspose2d(z_dim, 4*hidden_dim, 3, 2),
            nn.BatchNorm2d(4*hidden_dim),
            nn.ReLU(inplace=True),
 
            nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, 4, 1),
            nn.BatchNorm2d(hidden_dim*2),
            nn.ReLU(inplace=True),
 
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim ,3 ,2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
             
            nn.ConvTranspose2d(hidden_dim, im_chan, 4, 2),
            nn.Tanh()
        )
  def forward(self, noise):
       
      # Define how the generator computes the output
 
      noise = noise.view(len(noise), self.z_dim, 1, 1)
      return self.gen(noise)

ジェネレータクラス。

# We define a generator with latent dimension 100 and img_dim 1
gen = Generator(100, 1)
print("Composition of the Generator:", end="

")

print(gen)
Compostion of the Generator:
 
Generator(
  (gen): Sequential(
    (0): ConvTranspose2d(100, 256, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2))
    (10): Tanh()
  )
)
class Discriminator(nn.Module):
    def __init__(self, im_chan, hidden_dim=16):
         
        super().__init__()
        self.disc = nn.Sequential(
             
            # Discriminator is defined as a stack of
            # convolution layers with batch normalization
            # and non-linear activations.
 
            nn.Conv2d(im_chan, hidden_dim, 4, 2),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2,inplace=True),
             
            nn.Conv2d(hidden_dim, hidden_dim * 2, 4, 2),
            nn.BatchNorm2d(hidden_dim*2),
            nn.LeakyReLU(0.2,inplace=True),
             
            nn.Conv2d(hidden_dim*2, 1, 4, 2)
        )
 
    def forward(self, image):
 
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

補足説明 画像は非常に高次元のデータです。3x128x128 の RGB 画像でもサイズは 49152 です。

このような巨大な空間の部分空間あるいは多様体に、我々が求める画像が存在します。

理想的には、生成器がその部分空間の位置を学習し、学習した部分空間からランダムにサンプリングして出力を生成することです。

この理想的な部分空間の探索は非常に計算量の多い作業であり、これに対処するための最も一般的な方法は、プッシュフォワードを用いて潜在ベクトル空間をデータ空間にマッピングすることです。

2. 識別器

識別器Dはより単純であるが、重要な役割を担っている。識別器は、入力データが元の分布のものなのか、我々の生成器のものなのかを示す2値分類器です。理想的な識別器は、元の分布からのデータを真とし、Gからのデータを偽とするものです。

# We define a discriminator for one class classification
disc = Discriminator(1)
print("Composition of the Discriminator:", end="

")

print(disc)
Composition of the Discriminator:
 
Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(32, 1, kernel_size=(4, 4), stride=(2, 2))
  )
def gen_loss(gen, disc, num_images, latent_dim, device):
    
    # Generate the the fake images
    noise = random_noise(num_images, latent_dim).to(device)
    gen_img = gen(noise)
     
    # Pass through discriminator and find the binary cross entropy loss
    disc_gen = disc(gen_img)
    gen_loss = Loss(disc_gen, torch.ones_like(disc_gen))
     
    return gen_loss
def disc_loss(gen, disc, real_images, num_images, latent_dim, device):
     
    # Generate the fake images
    noise = random_noise(num_images, latent_dim).to(device);
    img_gen = gen(noise).detach()
     
    # Pass the real and fake images through discriminator
    disc_gen = disc(img_gen)
    disc_real = disc(real_images)
     
    # Find loss for the generator and discriminator
    gen_loss  = Loss(disc_gen, torch.zeros_like(disc_gen))
    real_loss = Loss(disc_real, torch.ones_like(disc_real))
     
    # Average over the losses for the discriminator loss
    disc_loss = ((gen_loss + real_loss) /2).mean()
 
    return disc_loss
# Set the batch size
BATCH_SIZE = 512
 
# Download the data in the Data folder in the directory above the current folder
data_iter = DataLoader(
                MNIST('../Data', download=True, transform=transforms.ToTensor()),
                      batch_size=BATCH_SIZE,
                      shuffle=True)

GANにおける損失関数

ここで、生成器と識別器に対する損失を定義します。

1. ジェネレータの損失

Generatorは、識別器を騙して本物だと思わせるような画像を生成しようとします。

そのため、生成器は偽の画像を真のラベルに割り当てる確率を最大化しようとします。

そのため、生成器損失は生成された画像が偽物と判別される確率の期待値です。

# Set Loss as Binary CrossEntropy with logits
Loss = nn.BCEWithLogitsLoss()
# Set the latent dimension
latent_dim = 100
display_step = 500
# Set the learning rate
lr = 0.0002
 
# Set the beta_1 and beta_2 for the optimizer
beta_1 = 0.5
beta_2 = 0.999

2. 識別器の損失

実画像に真のラベルを付与する確率を最大にし、偽画像に偽のラベルを付与する確率を最大にする識別器を求める。

生成器損失と同様に、識別器損失は実画像が偽画像に分類される確率と偽画像が実画像に分類される確率です。

この2つのモデルの損失関数が互いにどのように作用しているかに注目しよう。

device = "cpu"
if torch.cuda.is_available():
  device = "cuda"
device

MNISTトレーニングデータセットのロード

MNISTの学習用データをロードします。必要なデータセットのダウンロードには torchvision パッケージを使用します。

# Initialize the Generator and the Discriminator along with
# their optimizer gen_opt and disc_opt
# We choose ADAM as the optimizer for both models
gen = Generator(latent_dim, 1).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator(1 ).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))
 
 
# Initialize the weights of the various layers
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
 
# Apply the initial weights on the generator and discriminator
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

モデルの初期化

モデルのハイパーパラメータを設定します。

def display_images(image_tensor, num_images=25, size=(1, 28, 28)):
 
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

ハードウェアアクセラレーションを有効にしているかどうかに応じて、デバイスを cpu または cuda に設定します。

def random_noise(n_samples, z_dim):
  return torch.randn(n_samples, z_dim)

ジェネレータ、識別器、オプティマイザを初期化します。また、レイヤーの開始/初期重みを初期化します。

# Set the number of epochs
num_epochs = 100
# Set the interval at which generated images will be displayed
display_step = 100
# Inter parameter
itr = 0
 
for epoch in range(num_epochs):
  for images, _ in data_iter:
    
   num_images = len(images)
   # Transfer the images to cuda if harware accleration is present
   real_images = images.to(device)
    
   # Discriminator step
   disc_opt.zero_grad()
   D_loss = disc_loss(gen, disc, real_images, num_images, latent_dim, device)
   D_loss.backward(retain_graph=True)
   disc_opt.step()
    
   # Generator Step
   gen_opt.zero_grad()
   G_loss = gen_loss(gen, disc, num_images, latent_dim, device)
   G_loss.backward(retain_graph=True)
   gen_opt.step()
 
   if itr% display_step ==0 :
    with torch.no_grad():
      # Clear the previous output
      clear_output(wait=True)
      noise =  noise = random_noise(25,latent_dim).to(device)
      img = gen(noise)
      # Display the generated images
      display_images(img)
  itr+=1

ユーティリティ機能の設定

私たちは常に、アプリケーションに特化したものではないが、いくつかの作業を容易にするユーティリティ関数が必要です。ここでは、torchvision の make_grid 関数を利用して、画像をグリッド状に表示する関数を定義します。

Generator Logic
Fig 1: Working of the Generator

ジェネレータの入力となるランダムなノイズを生成するためのノイズ関数を定義しています。

Discriminator Logic
Fig 2: Working of the disrciminator

PyTorchでGANを学習させるループ

Model Working
Fig 3: Working of the model

結果

GANの結果の一部を紹介します。

まとめ

画像の集合から新しい画像を生成する方法について見てきました。GAN は数字の画像に限定されない。最近のGANは、人間の顔をリアルに生成できるほど強力です。GANは現在、音楽、芸術などの生成に使われている。GANの仕組みについてもっと知りたい方は、GoodfellowによるオリジナルのGANの論文を参照してください。

タイトルとURLをコピーしました