半教師付き学習(SSL)を用いてラベルのないデータから分類器モデルを学習するための初心者向けチュートリアルです。

スポンサーリンク

従来、分類器のようなコンピュータビジョンモデルの学習には、ラベル付けされたデータが必要であった。

学習データの各例は、画像と、その画像を説明する人間が作成したラベルのペアである必要がありました。

近年、新しいSSL技術は、Imagenetのような古典的な課題に対して、コンピュータビジョンで最も正確なモデルを提供することができるようになりました。

半教師付き学習(SSL)は、ラベル付けされたデータとラベル付けされていないデータの両方からモデルを学習させるものです。

非ラベル化データは、ラベルのない画像のみから構成される。

SSLが優れているのは、通常、ラベル付きデータよりもラベルなしデータの方が多いからです。

特に、モデルを実運用に配備した後は、ラベルなしデータの方が多くなります。

また、SSLはラベリングにかかる時間、コスト、労力を削減します。

しかし、ラベルのない画像からモデルはどのように学習するのでしょうか?重要な洞察は、画像そのものが情報を持っているということです。

SSLの魔法は、その構造に基づいて類似している画像を自動的にクラスタリングすることにより、ラベルのないデータから情報を抽出できることであり、このクラスタリングはモデルが学習するための追加情報を提供します。

この記事では、Google Colabに含まれるmatplotlib、numpy、TensorFlowなどのいくつかの一般的なPythonライブラリを使用します。

もしそれらをインストールする必要があるなら、通常Jupyterノートブック内で !pip install --upgrade pip; pip install matplotlib numpy tensorflow またはコマンドラインから pip install --upgrade pip; pip install matplotlib numpy tensorflow (exclamation pointなし) を実行することが可能です。

Google Colabを使用している場合は、ランタイムタイプをGPUに変更することを確認してください。

この記事では、CIFAR-10データセットで分類器を学習させましょう。

これは自然画像からなる古典的な研究データセットです。

ロードして見てみましょう。

CIFAR-10 のクラスは、カエル、ボート、車、トラック、鹿、馬、鳥、猫、犬、そして飛行機です。

import matplotlib.pyplot as plt
 
def plot_images(images):
  """Simple utility to render images."""
  # Visualize the data.
  _, axarr = plt.subplots(5, 5, figsize=(15,15))
 
  for row in range(5):
    for col in range(5):
      image = images[row*5 + col]
      axarr[row, col].imshow(image)
       
import tensorflow as tf
 
NUM_CLASSES = 10
# Load the data using the Keras Datasets API.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
 
plot_images(x_test)
def get_model():   
    return tf.keras.applications.MobileNet(input_shape=(32,32,3),
                                           weights=None,
                                           classes=NUM_CLASSES,
                                           classifier_activation=None)
 
model = get_model()
スポンサーリンク

モデルの作成

一般に、モデル・アーキテクチャは既製のものを使用するのがよいでしょう。

そうすれば、モデルアーキテクチャの設計に頭を悩ませる手間が省けます。

モデルサイズの一般的なルールは、データを処理するのに十分な大きさで、かつ推論時に遅くならない大きさのモデルを選択することです。

CIFAR-10のような非常に小さなデータセットには、非常に小さなモデルを使用します。

画像サイズの大きなデータセットには、Efficient Netファミリーが適しています。

def normalize_data(x_train, y_train, x_test, y_test):
  """Utility to normalize the data into standard formats."""
 
  # Update the pixel range to [-1,1], which is expected by the model architecture.
  x_train = x = tf.keras.applications.mobilenet.preprocess_input(x_train)
  x_test = x = tf.keras.applications.mobilenet.preprocess_input(x_test)
 
  # Convert to one-hot labels.
  y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
  y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)
 
  return x_train, y_train, x_test, y_test
   
x_train, y_train, x_test, y_test =
  normalize_data(x_train, y_train, x_test, y_test)

データを用意する

それでは、10種類の物体を表す0から9の整数であるラベルを、[1,0,0,0,0,0,0,0]や[0,0,0,0,0,1]という一発のベクトルに変換してデータを用意しましょう。

また、画像のピクセルをモデルアーキテクチャが期待する範囲、すなわち[-1, 1]に更新することにします。

import numpy as np
 
def prepare_data(x_train, y_train, num_labeled_examples, num_unlabeled_examples):
    """Returns labeled and unlabeled datasets."""
    num_examples = x_train.size
 
    assert num_labeled_examples + num_unlabeled_examples <= num_examples
 
    # Generate some random indices.
    dataset_size = len(x_train)
    indices = np.array(range(dataset_size))
    generator = np.random.default_rng(seed=0)
    generator.shuffle(indices)
 
    # Split the indices into two sets: one for labeled, one for unlabeled.
    labeled_train_indices = indices[:num_labeled_examples]
    unlabeled_train_indices = indices[num_labeled_examples : num_labeled_examples + num_unlabeled_examples]
 
    x_labeled_train = x_train[labeled_train_indices]
    y_labeled_train = y_train[labeled_train_indices]
 
    x_unlabeled_train = x_train[unlabeled_train_indices]
    # Since this is unlabeled, we won't need a y_labeled_data.
 
    return x_labeled_train, y_labeled_train, x_unlabeled_train
 
NUM_LABELED = 5000
NUM_UNLABELED = 20000
 
x_labeled_train, y_labeled_train, x_unlabeled_train =
    prepare_data(x_train,
                 y_train,
                 num_labeled_examples=NUM_LABELED,
                 num_unlabeled_examples=NUM_UNLABELED)
 
del x_train, y_train

このデータセットには50,000の例が含まれている。

そのうち5,000例をラベル付き画像とし、20,000例をラベルなし画像として使用することにします。

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
 
# Setup Keras augmentation.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    horizontal_flip=True)
 
datagen.fit(x_labeled_train)
 
batch_size = 64
epochs = 30
model.fit(
    x = datagen.flow(x_labeled_train, y_labeled_train, batch_size=batch_size),
    shuffle=True,
    validation_data=(x_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
)
baseline_metrics = model.evaluate(x=x_test, y=y_test, return_dict=True)
print('')
print(f"Baseline model accuracy: {baseline_metrics['categorical_accuracy']}")

ベースライン・トレーニング

SSLによる性能向上を測定するために、まず、SSLを使用しない標準的なトレーニングループでモデルの性能を測定してみましょう。

いくつかの基本的なデータ補強を行った標準的なトレーニングループをセットアップしましょう。

データ増強は正則化の一種で、オーバーフィッティングを防ぎ、モデルが見たことのないデータに対してよりよく汎化できるようにします。

以下のハイパーパラメータ値(学習率、エポック、バッチサイズなど)は、一般的なデフォルト値と手動でチューニングした値の組み合わせです。

その結果、約45%の精度のモデルが出来上がりました。

(学習精度ではなく、検証精度を読むことを忘れないでください)。

次の課題は、SSLを使用してモデルの精度を向上させることができるかどうかを見極めることです。

Epoch 1/30
79/79 [==============================] - 4s 23ms/step - loss: 2.4214 - categorical_accuracy: 0.1578 - val_loss: 2.3047 - val_categorical_accuracy: 0.1000
Epoch 2/30
79/79 [==============================] - 1s 16ms/step - loss: 2.0831 - categorical_accuracy: 0.2196 - val_loss: 2.3063 - val_categorical_accuracy: 0.1000
Epoch 3/30
79/79 [==============================] - 1s 16ms/step - loss: 1.9363 - categorical_accuracy: 0.2852 - val_loss: 2.3323 - val_categorical_accuracy: 0.1000
Epoch 4/30
79/79 [==============================] - 1s 16ms/step - loss: 1.8324 - categorical_accuracy: 0.3174 - val_loss: 2.3496 - val_categorical_accuracy: 0.1000
Epoch 5/30
79/79 [==============================] - 1s 16ms/step - loss: 1.8155 - categorical_accuracy: 0.3438 - val_loss: 2.3339 - val_categorical_accuracy: 0.1000
Epoch 6/30
79/79 [==============================] - 1s 15ms/step - loss: 1.6477 - categorical_accuracy: 0.3886 - val_loss: 2.3606 - val_categorical_accuracy: 0.1000
Epoch 7/30
79/79 [==============================] - 1s 15ms/step - loss: 1.6120 - categorical_accuracy: 0.4100 - val_loss: 2.3585 - val_categorical_accuracy: 0.1000
Epoch 8/30
79/79 [==============================] - 1s 16ms/step - loss: 1.5884 - categorical_accuracy: 0.4220 - val_loss: 2.1796 - val_categorical_accuracy: 0.2519
Epoch 9/30
79/79 [==============================] - 1s 18ms/step - loss: 1.5477 - categorical_accuracy: 0.4310 - val_loss: 1.8913 - val_categorical_accuracy: 0.3145
Epoch 10/30
79/79 [==============================] - 1s 15ms/step - loss: 1.4328 - categorical_accuracy: 0.4746 - val_loss: 1.7082 - val_categorical_accuracy: 0.3696
Epoch 11/30
79/79 [==============================] - 1s 16ms/step - loss: 1.4328 - categorical_accuracy: 0.4796 - val_loss: 1.7679 - val_categorical_accuracy: 0.3811
Epoch 12/30
79/79 [==============================] - 2s 20ms/step - loss: 1.3962 - categorical_accuracy: 0.5020 - val_loss: 1.8994 - val_categorical_accuracy: 0.3690
Epoch 13/30
79/79 [==============================] - 1s 16ms/step - loss: 1.3271 - categorical_accuracy: 0.5156 - val_loss: 2.0416 - val_categorical_accuracy: 0.3688
Epoch 14/30
79/79 [==============================] - 1s 17ms/step - loss: 1.2711 - categorical_accuracy: 0.5374 - val_loss: 1.9231 - val_categorical_accuracy: 0.3848
Epoch 15/30
79/79 [==============================] - 1s 15ms/step - loss: 1.2312 - categorical_accuracy: 0.5624 - val_loss: 1.9006 - val_categorical_accuracy: 0.3961
Epoch 16/30
79/79 [==============================] - 1s 19ms/step - loss: 1.2048 - categorical_accuracy: 0.5720 - val_loss: 2.0102 - val_categorical_accuracy: 0.4102
Epoch 17/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1365 - categorical_accuracy: 0.6000 - val_loss: 2.1400 - val_categorical_accuracy: 0.3672
Epoch 18/30
79/79 [==============================] - 1s 18ms/step - loss: 1.1992 - categorical_accuracy: 0.5840 - val_loss: 2.1206 - val_categorical_accuracy: 0.3933
Epoch 19/30
79/79 [==============================] - 2s 25ms/step - loss: 1.1438 - categorical_accuracy: 0.6012 - val_loss: 2.4035 - val_categorical_accuracy: 0.4014
Epoch 20/30
79/79 [==============================] - 2s 24ms/step - loss: 1.1211 - categorical_accuracy: 0.6018 - val_loss: 2.0224 - val_categorical_accuracy: 0.4010
Epoch 21/30
79/79 [==============================] - 2s 21ms/step - loss: 1.0425 - categorical_accuracy: 0.6358 - val_loss: 2.2100 - val_categorical_accuracy: 0.3911
Epoch 22/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1177 - categorical_accuracy: 0.6116 - val_loss: 1.9892 - val_categorical_accuracy: 0.4285
Epoch 23/30
79/79 [==============================] - 1s 19ms/step - loss: 1.0236 - categorical_accuracy: 0.6412 - val_loss: 2.1216 - val_categorical_accuracy: 0.4211
Epoch 24/30
79/79 [==============================] - 1s 18ms/step - loss: 0.9487 - categorical_accuracy: 0.6714 - val_loss: 2.0135 - val_categorical_accuracy: 0.4307
Epoch 25/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1877 - categorical_accuracy: 0.5876 - val_loss: 2.3732 - val_categorical_accuracy: 0.3923
Epoch 26/30
79/79 [==============================] - 2s 20ms/step - loss: 1.0639 - categorical_accuracy: 0.6288 - val_loss: 1.9291 - val_categorical_accuracy: 0.4291
Epoch 27/30
79/79 [==============================] - 2s 19ms/step - loss: 0.9243 - categorical_accuracy: 0.6882 - val_loss: 1.8552 - val_categorical_accuracy: 0.4343
Epoch 28/30
79/79 [==============================] - 1s 15ms/step - loss: 0.9784 - categorical_accuracy: 0.6656 - val_loss: 2.0175 - val_categorical_accuracy: 0.4386
Epoch 29/30
79/79 [==============================] - 1s 17ms/step - loss: 0.9316 - categorical_accuracy: 0.6800 - val_loss: 1.9916 - val_categorical_accuracy: 0.4305
Epoch 30/30
79/79 [==============================] - 1s 17ms/step - loss: 0.8816 - categorical_accuracy: 0.7054 - val_loss: 2.0281 - val_categorical_accuracy: 0.4366
313/313 [==============================] - 1s 3ms/step - loss: 2.0280 - categorical_accuracy: 0.4366
 
Baseline model accuracy: 0.436599999666214

結果は、以下の通りになります。

!pip install --upgrade pip
!pip install masterful
 
import masterful
 
masterful = masterful.register()

SSLを使ったトレーニング

では、学習データにラベルのないデータを追加することで、モデルの精度を向上させることができるか見てみましょう。

今回の分類器のようなコンピュータビジョンモデルにSSLを実装するプラットフォームであるMasterfulを使用します。

Masterfulをインストールしましょう。

Google Colabでは、ノートブックセルからpip installすることができます

また、コマンドラインでもインストールできます。

詳しくは、Masterfulのインストールガイドをご覧ください。

Loaded Masterful version 0.4.1. This software is distributed free of
charge for personal projects and evaluation purposes.
See http://www.masterfulai.com/personal-and-evaluation-agreement for details.
Sign up in the next 45 days at https://www.masterfulai.com/get-it-now
to continue using Masterful.

結果は以下の通りです。

# Start fresh with a new model
tf.keras.backend.clear_session()
model = get_model()
 
# Tell Masterful that your model is performing a classification task
# with 10 labels and that the image pixel range is
# [-1,1]. Also, the model outputs logits rather than a softmax activation.
model_params = masterful.architecture.learn_architecture_params(
    model=model,
    task=masterful.enums.Task.CLASSIFICATION,
    input_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    prediction_logits=True,
)
 
# Tell Masterful that your labeled training data is using one-hot labels.
labeled_training_data_params = masterful.data.learn_data_params(
    dataset=(x_labeled_train, y_labeled_train),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)
 
unlabeled_training_data_params = masterful.data.learn_data_params(
    dataset=(x_unlabeled_train,),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=None,
)
 
# Tell Masterful that your test/validation data is using one-hot labels.
test_data_params = masterful.data.learn_data_params(
    dataset=(x_test, y_test),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)
 
# Let Masterful meta-learn ideal optimization hyperparameters like
# batch size, learning rate, optimizer, learning rate schedule, and epochs.
# This will speed up training.
optimization_params = masterful.optimization.learn_optimization_params(
    model,
    model_params,
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
)
 
# Let Masterful meta-learn ideal regularization hyperparameters. Regularization
# is an important ingredient of SSL. Meta-learning can
# take a while so we'll use a precached set of parameters.
# regularization_params =
#   masterful.regularization.learn_regularization_params(model,
#                                                        model_params,
#                                                        optimization_params,
#                                                        (x_labeled_train, y_labeled_train),
#                                                        labeled_training_data_params)
 
regularization_params = masterful.regularization.parameters.CIFAR10_SMALL
 
# Let Masterful meta-learn ideal SSL hyperparameters.
ssl_params = masterful.ssl.learn_ssl_params(
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
    unlabeled_datasets=[((x_unlabeled_train,), unlabeled_training_data_params)],
)

Setup Masterful

では、Masterfulの設定パラメータをいくつか設定します。

MASTERFUL: Learning optimal batch size.
MASTERFUL: Learning optimal initial learning rate for batch size 256.

結果は以下の通りです。

training_report = masterful.training.train(
    model,
    model_params,
    optimization_params,
    regularization_params,
    ssl_params,
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
    (x_test, y_test),
    test_data_params,
    unlabeled_datasets=[((x_unlabeled_train,), unlabeled_training_data_params)],
)

トレイン!

さて、SSL技術を使ったトレーニングの準備が整いました! Masterfulのトレーニングエンジンへのエントリポイントであるmasterful.training.trainを呼び出すことにします。

MASTERFUL: Training model with semi-supervised learning enabled.
MASTERFUL: Performing basic dataset analysis.
MASTERFUL: Training model with:
MASTERFUL:  5000 labeled examples.
MASTERFUL:  10000 validation examples.
MASTERFUL:  0 synthetic examples.
MASTERFUL:  20000 unlabeled examples.
MASTERFUL: Training model with learned parameters partridge-boiled-cap in two phases.
MASTERFUL: The first phase is supervised training with the learned parameters.
MASTERFUL: The second phase is semi-supervised training to boost performance.
MASTERFUL: Warming up model for supervised training.
MASTERFUL:  Warming up batch norm statistics (this could take a few minutes).
MASTERFUL:  Warming up training for 500 steps.
100%|██████████| 500/500 [00:47<00:00, 10.59steps/s]
MASTERFUL:  Validating batch norm statistics after warmup for stability (this could take a few minutes).
MASTERFUL: Starting Phase 1: Supervised training until the validation loss stabilizes...
Supervised Training: 100%|██████████| 6300/6300 [02:33<00:00, 41.13steps/s]
MASTERFUL: Starting Phase 2: Semi-supervised training until the validation loss stabilizes...
MASTERFUL: Warming up model for semi-supervised training.
MASTERFUL:  Warming up batch norm statistics (this could take a few minutes).
MASTERFUL:  Warming up training for 500 steps.
100%|██████████| 500/500 [00:23<00:00, 20.85steps/s]
MASTERFUL:  Validating batch norm statistics after warmup for stability (this could take a few minutes).
Semi-Supervised Training: 100%|██████████| 11868/11868 [08:06<00:00, 24.39steps/s]

結果は以下の通りです。

masterful_metrics = model.evaluate(
    x_test, y_test, return_dict=True, verbose=0
)
print(f"Baseline model accuracy: {baseline_metrics['categorical_accuracy']}")
print(f"Masterful model accuracy: {masterful_metrics['categorical_accuracy']}")

結果を分析する

masterful.training.trainに渡したモデルは学習・更新されたので、他の学習済みKerasモデルと同じように評価することができます

Baseline model accuracy: 0.436599999666214
Masterful model accuracy: 0.558899998664856

結果は以下の通りです。

import matplotlib.cm as cm
from matplotlib.colors import Normalize
  
data = (baseline_metrics['categorical_accuracy'], masterful_metrics['categorical_accuracy'])
fig, ax = plt.subplots(1, 1)
    
ax.bar(range(2), data, color=('gray', 'red'))
 
plt.xlabel("Training Method")
plt.ylabel("Accuracy")
 
plt.xticks((0,1), ("baseline", "SSL with Masterful"))
 
plt.show()

結果を可視化する

ご覧の通り、精度率が約0.45から0.56に向上していますね。

もちろん、より厳密な研究では、ベースラインのトレーニングとMasterfulプラットフォーム経由でSSLを使用したトレーニングの間の他の違いを削除することを試みるでしょうし、実行を数回繰り返し、エラーバーとp値を生成することもできます。

とりあえず、この結果を説明するために、グラフとしてプロットしてみましょう。

Sample Image
Masterful Training Method

まとめ

簡単なチュートリアルで、最も高度な学習方法の1つであるSSLを利用して、モデルの精度を向上させることに成功したのです。

その過程で、あなたはラベリングにかかるコストと労力を回避することができました。

SSLは分類だけでなく、あらゆるコンピュータビジョンタスクに適用できます。

このテーマをより深く掘り下げ、SSLが物体検出に使われている様子を見るには、こちらのチュートリアルを参照してください。

タイトルとURLをコピーしました