従来、分類器のようなコンピュータビジョンモデルの学習には、ラベル付けされたデータが必要であった。
学習データの各例は、画像と、その画像を説明する人間が作成したラベルのペアである必要がありました。
近年、新しいSSL技術は、Imagenetのような古典的な課題に対して、コンピュータビジョンで最も正確なモデルを提供することができるようになりました。
半教師付き学習(SSL)は、ラベル付けされたデータとラベル付けされていないデータの両方からモデルを学習させるものです。
非ラベル化データは、ラベルのない画像のみから構成される。
SSLが優れているのは、通常、ラベル付きデータよりもラベルなしデータの方が多いからです。
特に、モデルを実運用に配備した後は、ラベルなしデータの方が多くなります。
また、SSLはラベリングにかかる時間、コスト、労力を削減します。
しかし、ラベルのない画像からモデルはどのように学習するのでしょうか?重要な洞察は、画像そのものが情報を持っているということです。
SSLの魔法は、その構造に基づいて類似している画像を自動的にクラスタリングすることにより、ラベルのないデータから情報を抽出できることであり、このクラスタリングはモデルが学習するための追加情報を提供します。
この記事では、Google Colabに含まれるmatplotlib、numpy、TensorFlowなどのいくつかの一般的なPythonライブラリを使用します。
もしそれらをインストールする必要があるなら、通常Jupyterノートブック内で !pip install --upgrade pip; pip install matplotlib numpy tensorflow
またはコマンドラインから pip install --upgrade pip; pip install matplotlib numpy tensorflow
(exclamation pointなし) を実行することが可能です。
Google Colabを使用している場合は、ランタイムタイプをGPUに変更することを確認してください。
この記事では、CIFAR-10データセットで分類器を学習させましょう。
これは自然画像からなる古典的な研究データセットです。
ロードして見てみましょう。
CIFAR-10 のクラスは、カエル、ボート、車、トラック、鹿、馬、鳥、猫、犬、そして飛行機です。
import matplotlib.pyplot as plt
def plot_images(images):
"""Simple utility to render images."""
# Visualize the data.
_, axarr = plt.subplots( 5 , 5 , figsize = ( 15 , 15 ))
for row in range ( 5 ):
for col in range ( 5 ):
image = images[row * 5 + col]
axarr[row, col].imshow(image)
import tensorflow as tf
NUM_CLASSES = 10
# Load the data using the Keras Datasets API. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
plot_images(x_test) |
def get_model():
return tf.keras.applications.MobileNet(input_shape = ( 32 , 32 , 3 ),
weights = None ,
classes = NUM_CLASSES,
classifier_activation = None )
model = get_model()
|
この記事もチェック:Pythonでデータセットから別のデータセットへピボットテーブルを作成する
モデルの作成
一般に、モデル・アーキテクチャは既製のものを使用するのがよいでしょう。
そうすれば、モデルアーキテクチャの設計に頭を悩ませる手間が省けます。
モデルサイズの一般的なルールは、データを処理するのに十分な大きさで、かつ推論時に遅くならない大きさのモデルを選択することです。
CIFAR-10のような非常に小さなデータセットには、非常に小さなモデルを使用します。
画像サイズの大きなデータセットには、Efficient Netファミリーが適しています。
def normalize_data(x_train, y_train, x_test, y_test):
"""Utility to normalize the data into standard formats."""
# Update the pixel range to [-1,1], which is expected by the model architecture.
x_train = x = tf.keras.applications.mobilenet.preprocess_input(x_train)
x_test = x = tf.keras.applications.mobilenet.preprocess_input(x_test)
# Convert to one-hot labels.
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)
return x_train, y_train, x_test, y_test
x_train, y_train, x_test, y_test =
normalize_data(x_train, y_train, x_test, y_test)
|
この記事もチェック:PyTorchで音声や画像のカスタムデータセットを自作する方法
データを用意する
それでは、10種類の物体を表す0から9の整数であるラベルを、[1,0,0,0,0,0,0,0]や[0,0,0,0,0,1]という一発のベクトルに変換してデータを用意しましょう。
また、画像のピクセルをモデルアーキテクチャが期待する範囲、すなわち[-1, 1]に更新することにします。
import numpy as np
def prepare_data(x_train, y_train, num_labeled_examples, num_unlabeled_examples):
"""Returns labeled and unlabeled datasets."""
num_examples = x_train.size
assert num_labeled_examples + num_unlabeled_examples < = num_examples
# Generate some random indices.
dataset_size = len (x_train)
indices = np.array( range (dataset_size))
generator = np.random.default_rng(seed = 0 )
generator.shuffle(indices)
# Split the indices into two sets: one for labeled, one for unlabeled.
labeled_train_indices = indices[:num_labeled_examples]
unlabeled_train_indices = indices[num_labeled_examples : num_labeled_examples + num_unlabeled_examples]
x_labeled_train = x_train[labeled_train_indices]
y_labeled_train = y_train[labeled_train_indices]
x_unlabeled_train = x_train[unlabeled_train_indices]
# Since this is unlabeled, we won't need a y_labeled_data.
return x_labeled_train, y_labeled_train, x_unlabeled_train
NUM_LABELED = 5000
NUM_UNLABELED = 20000
x_labeled_train, y_labeled_train, x_unlabeled_train =
prepare_data(x_train,
y_train,
num_labeled_examples = NUM_LABELED,
num_unlabeled_examples = NUM_UNLABELED)
del x_train, y_train
|
このデータセットには50,000の例が含まれている。
そのうち5,000例をラベル付き画像とし、20,000例をラベルなし画像として使用することにします。
model. compile (
optimizer = tf.keras.optimizers.Adam(),
loss = tf.keras.losses.CategoricalCrossentropy(from_logits = True ),
metrics = [tf.keras.metrics.CategoricalAccuracy()],
) # Setup Keras augmentation. datagen = tf.keras.preprocessing.image.ImageDataGenerator(
featurewise_center = False ,
featurewise_std_normalization = False ,
horizontal_flip = True )
datagen.fit(x_labeled_train) batch_size = 64
epochs = 30
model.fit( x = datagen.flow(x_labeled_train, y_labeled_train, batch_size = batch_size),
shuffle = True ,
validation_data = (x_test, y_test),
batch_size = batch_size,
epochs = epochs,
) baseline_metrics = model.evaluate(x = x_test, y = y_test, return_dict = True )
print ('')
print (f "Baseline model accuracy: {baseline_metrics['categorical_accuracy']}" )
|
ベースライン・トレーニング
SSLによる性能向上を測定するために、まず、SSLを使用しない標準的なトレーニングループでモデルの性能を測定してみましょう。
いくつかの基本的なデータ補強を行った標準的なトレーニングループをセットアップしましょう。
データ増強は正則化の一種で、オーバーフィッティングを防ぎ、モデルが見たことのないデータに対してよりよく汎化できるようにします。
以下のハイパーパラメータ値(学習率、エポック、バッチサイズなど)は、一般的なデフォルト値と手動でチューニングした値の組み合わせです。
その結果、約45%の精度のモデルが出来上がりました。
(学習精度ではなく、検証精度を読むことを忘れないでください)。
次の課題は、SSLを使用してモデルの精度を向上させることができるかどうかを見極めることです。
Epoch 1 /30
79 /79 [==============================] - 4s 23ms /step - loss: 2.4214 - categorical_accuracy: 0.1578 - val_loss: 2.3047 - val_categorical_accuracy: 0.1000
Epoch 2 /30
79 /79 [==============================] - 1s 16ms /step - loss: 2.0831 - categorical_accuracy: 0.2196 - val_loss: 2.3063 - val_categorical_accuracy: 0.1000
Epoch 3 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.9363 - categorical_accuracy: 0.2852 - val_loss: 2.3323 - val_categorical_accuracy: 0.1000
Epoch 4 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.8324 - categorical_accuracy: 0.3174 - val_loss: 2.3496 - val_categorical_accuracy: 0.1000
Epoch 5 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.8155 - categorical_accuracy: 0.3438 - val_loss: 2.3339 - val_categorical_accuracy: 0.1000
Epoch 6 /30
79 /79 [==============================] - 1s 15ms /step - loss: 1.6477 - categorical_accuracy: 0.3886 - val_loss: 2.3606 - val_categorical_accuracy: 0.1000
Epoch 7 /30
79 /79 [==============================] - 1s 15ms /step - loss: 1.6120 - categorical_accuracy: 0.4100 - val_loss: 2.3585 - val_categorical_accuracy: 0.1000
Epoch 8 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.5884 - categorical_accuracy: 0.4220 - val_loss: 2.1796 - val_categorical_accuracy: 0.2519
Epoch 9 /30
79 /79 [==============================] - 1s 18ms /step - loss: 1.5477 - categorical_accuracy: 0.4310 - val_loss: 1.8913 - val_categorical_accuracy: 0.3145
Epoch 10 /30
79 /79 [==============================] - 1s 15ms /step - loss: 1.4328 - categorical_accuracy: 0.4746 - val_loss: 1.7082 - val_categorical_accuracy: 0.3696
Epoch 11 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.4328 - categorical_accuracy: 0.4796 - val_loss: 1.7679 - val_categorical_accuracy: 0.3811
Epoch 12 /30
79 /79 [==============================] - 2s 20ms /step - loss: 1.3962 - categorical_accuracy: 0.5020 - val_loss: 1.8994 - val_categorical_accuracy: 0.3690
Epoch 13 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.3271 - categorical_accuracy: 0.5156 - val_loss: 2.0416 - val_categorical_accuracy: 0.3688
Epoch 14 /30
79 /79 [==============================] - 1s 17ms /step - loss: 1.2711 - categorical_accuracy: 0.5374 - val_loss: 1.9231 - val_categorical_accuracy: 0.3848
Epoch 15 /30
79 /79 [==============================] - 1s 15ms /step - loss: 1.2312 - categorical_accuracy: 0.5624 - val_loss: 1.9006 - val_categorical_accuracy: 0.3961
Epoch 16 /30
79 /79 [==============================] - 1s 19ms /step - loss: 1.2048 - categorical_accuracy: 0.5720 - val_loss: 2.0102 - val_categorical_accuracy: 0.4102
Epoch 17 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.1365 - categorical_accuracy: 0.6000 - val_loss: 2.1400 - val_categorical_accuracy: 0.3672
Epoch 18 /30
79 /79 [==============================] - 1s 18ms /step - loss: 1.1992 - categorical_accuracy: 0.5840 - val_loss: 2.1206 - val_categorical_accuracy: 0.3933
Epoch 19 /30
79 /79 [==============================] - 2s 25ms /step - loss: 1.1438 - categorical_accuracy: 0.6012 - val_loss: 2.4035 - val_categorical_accuracy: 0.4014
Epoch 20 /30
79 /79 [==============================] - 2s 24ms /step - loss: 1.1211 - categorical_accuracy: 0.6018 - val_loss: 2.0224 - val_categorical_accuracy: 0.4010
Epoch 21 /30
79 /79 [==============================] - 2s 21ms /step - loss: 1.0425 - categorical_accuracy: 0.6358 - val_loss: 2.2100 - val_categorical_accuracy: 0.3911
Epoch 22 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.1177 - categorical_accuracy: 0.6116 - val_loss: 1.9892 - val_categorical_accuracy: 0.4285
Epoch 23 /30
79 /79 [==============================] - 1s 19ms /step - loss: 1.0236 - categorical_accuracy: 0.6412 - val_loss: 2.1216 - val_categorical_accuracy: 0.4211
Epoch 24 /30
79 /79 [==============================] - 1s 18ms /step - loss: 0.9487 - categorical_accuracy: 0.6714 - val_loss: 2.0135 - val_categorical_accuracy: 0.4307
Epoch 25 /30
79 /79 [==============================] - 1s 16ms /step - loss: 1.1877 - categorical_accuracy: 0.5876 - val_loss: 2.3732 - val_categorical_accuracy: 0.3923
Epoch 26 /30
79 /79 [==============================] - 2s 20ms /step - loss: 1.0639 - categorical_accuracy: 0.6288 - val_loss: 1.9291 - val_categorical_accuracy: 0.4291
Epoch 27 /30
79 /79 [==============================] - 2s 19ms /step - loss: 0.9243 - categorical_accuracy: 0.6882 - val_loss: 1.8552 - val_categorical_accuracy: 0.4343
Epoch 28 /30
79 /79 [==============================] - 1s 15ms /step - loss: 0.9784 - categorical_accuracy: 0.6656 - val_loss: 2.0175 - val_categorical_accuracy: 0.4386
Epoch 29 /30
79 /79 [==============================] - 1s 17ms /step - loss: 0.9316 - categorical_accuracy: 0.6800 - val_loss: 1.9916 - val_categorical_accuracy: 0.4305
Epoch 30 /30
79 /79 [==============================] - 1s 17ms /step - loss: 0.8816 - categorical_accuracy: 0.7054 - val_loss: 2.0281 - val_categorical_accuracy: 0.4366
313 /313 [==============================] - 1s 3ms /step - loss: 2.0280 - categorical_accuracy: 0.4366
Baseline model accuracy: 0.436599999666214 |
結果は、以下の通りになります。
!pip install - - upgrade pip
!pip install masterful import masterful
masterful = masterful.register()
|
SSLを使ったトレーニング
では、学習データにラベルのないデータを追加することで、モデルの精度を向上させることができるか見てみましょう。
今回の分類器のようなコンピュータビジョンモデルにSSLを実装するプラットフォームであるMasterfulを使用します。
Masterfulをインストールしましょう。
Google Colabでは、ノートブックセルからpip installすることができます。
また、コマンドラインでもインストールできます。
詳しくは、Masterfulのインストールガイドをご覧ください。
Loaded Masterful version 0.4.1. This software is distributed free of
charge for personal projects and evaluation purposes.
See http: //www .masterfulai.com /personal-and-evaluation-agreement for details.
Sign up in the next 45 days at https: //www .masterfulai.com /get-it-now
to continue using Masterful.
|
結果は以下の通りです。
# Start fresh with a new model tf.keras.backend.clear_session() model = get_model()
# Tell Masterful that your model is performing a classification task # with 10 labels and that the image pixel range is # [-1,1]. Also, the model outputs logits rather than a softmax activation. model_params = masterful.architecture.learn_architecture_params(
model = model,
task = masterful.enums.Task.CLASSIFICATION,
input_range = masterful.enums.ImageRange.NEG_ONE_POS_ONE,
num_classes = NUM_CLASSES,
prediction_logits = True ,
) # Tell Masterful that your labeled training data is using one-hot labels. labeled_training_data_params = masterful.data.learn_data_params(
dataset = (x_labeled_train, y_labeled_train),
task = masterful.enums.Task.CLASSIFICATION,
image_range = masterful.enums.ImageRange.NEG_ONE_POS_ONE,
num_classes = NUM_CLASSES,
sparse_labels = False ,
) unlabeled_training_data_params = masterful.data.learn_data_params(
dataset = (x_unlabeled_train,),
task = masterful.enums.Task.CLASSIFICATION,
image_range = masterful.enums.ImageRange.NEG_ONE_POS_ONE,
num_classes = NUM_CLASSES,
sparse_labels = None ,
) # Tell Masterful that your test/validation data is using one-hot labels. test_data_params = masterful.data.learn_data_params(
dataset = (x_test, y_test),
task = masterful.enums.Task.CLASSIFICATION,
image_range = masterful.enums.ImageRange.NEG_ONE_POS_ONE,
num_classes = NUM_CLASSES,
sparse_labels = False ,
) # Let Masterful meta-learn ideal optimization hyperparameters like # batch size, learning rate, optimizer, learning rate schedule, and epochs. # This will speed up training. optimization_params = masterful.optimization.learn_optimization_params(
model,
model_params,
(x_labeled_train, y_labeled_train),
labeled_training_data_params,
) # Let Masterful meta-learn ideal regularization hyperparameters. Regularization # is an important ingredient of SSL. Meta-learning can # take a while so we'll use a precached set of parameters. # regularization_params = # masterful.regularization.learn_regularization_params(model, # model_params, # optimization_params, # (x_labeled_train, y_labeled_train), # labeled_training_data_params) regularization_params = masterful.regularization.parameters.CIFAR10_SMALL
# Let Masterful meta-learn ideal SSL hyperparameters. ssl_params = masterful.ssl.learn_ssl_params(
(x_labeled_train, y_labeled_train),
labeled_training_data_params,
unlabeled_datasets = [((x_unlabeled_train,), unlabeled_training_data_params)],
) |
この記事もチェック:Pythonとsklearnを使って機械学習パイプラインを実装する方法
Setup Masterful
では、Masterfulの設定パラメータをいくつか設定します。
MASTERFUL: Learning optimal batch size. MASTERFUL: Learning optimal initial learning rate for batch size 256.
|
結果は以下の通りです。
training_report = masterful.training.train(
model,
model_params,
optimization_params,
regularization_params,
ssl_params,
(x_labeled_train, y_labeled_train),
labeled_training_data_params,
(x_test, y_test),
test_data_params,
unlabeled_datasets = [((x_unlabeled_train,), unlabeled_training_data_params)],
) |
トレイン!
さて、SSL技術を使ったトレーニングの準備が整いました! Masterfulのトレーニングエンジンへのエントリポイントであるmasterful.training.trainを呼び出すことにします。
MASTERFUL: Training model with semi-supervised learning enabled. MASTERFUL: Performing basic dataset analysis. MASTERFUL: Training model with: MASTERFUL: 5000 labeled examples. MASTERFUL: 10000 validation examples. MASTERFUL: 0 synthetic examples. MASTERFUL: 20000 unlabeled examples. MASTERFUL: Training model with learned parameters partridge-boiled-cap in two phases.
MASTERFUL: The first phase is supervised training with the learned parameters. MASTERFUL: The second phase is semi-supervised training to boost performance. MASTERFUL: Warming up model for supervised training.
MASTERFUL: Warming up batch norm statistics (this could take a few minutes). MASTERFUL: Warming up training for 500 steps.
100%|██████████| 500 /500 [00:47<00:00, 10.59steps /s ]
MASTERFUL: Validating batch norm statistics after warmup for stability (this could take a few minutes).
MASTERFUL: Starting Phase 1: Supervised training until the validation loss stabilizes...
Supervised Training: 100%|██████████| 6300 /6300 [02:33<00:00, 41.13steps /s ]
MASTERFUL: Starting Phase 2: Semi-supervised training until the validation loss stabilizes...
MASTERFUL: Warming up model for semi-supervised training.
MASTERFUL: Warming up batch norm statistics (this could take a few minutes). MASTERFUL: Warming up training for 500 steps.
100%|██████████| 500 /500 [00:23<00:00, 20.85steps /s ]
MASTERFUL: Validating batch norm statistics after warmup for stability (this could take a few minutes).
Semi-Supervised Training: 100%|██████████| 11868 /11868 [08:06<00:00, 24.39steps /s ]
|
結果は以下の通りです。
masterful_metrics = model.evaluate(
x_test, y_test, return_dict = True , verbose = 0
) print (f "Baseline model accuracy: {baseline_metrics['categorical_accuracy']}" )
print (f "Masterful model accuracy: {masterful_metrics['categorical_accuracy']}" )
|
結果を分析する
masterful.training.trainに渡したモデルは学習・更新されたので、他の学習済みKerasモデルと同じように評価することができます。
Baseline model accuracy: 0.436599999666214 Masterful model accuracy: 0.558899998664856 |
結果は以下の通りです。
import matplotlib.cm as cm
from matplotlib.colors import Normalize
data = (baseline_metrics[ 'categorical_accuracy' ], masterful_metrics[ 'categorical_accuracy' ])
fig, ax = plt.subplots( 1 , 1 )
ax.bar( range ( 2 ), data, color = ( 'gray' , 'red' ))
plt.xlabel( "Training Method" )
plt.ylabel( "Accuracy" )
plt.xticks(( 0 , 1 ), ( "baseline" , "SSL with Masterful" ))
plt.show() |
結果を可視化する
ご覧の通り、精度率が約0.45から0.56に向上していますね。
もちろん、より厳密な研究では、ベースラインのトレーニングとMasterfulプラットフォーム経由でSSLを使用したトレーニングの間の他の違いを削除することを試みるでしょうし、実行を数回繰り返し、エラーバーとp値を生成することもできます。
とりあえず、この結果を説明するために、グラフとしてプロットしてみましょう。
まとめ
簡単なチュートリアルで、最も高度な学習方法の1つであるSSLを利用して、モデルの精度を向上させることに成功したのです。
その過程で、あなたはラベリングにかかるコストと労力を回避することができました。
SSLは分類だけでなく、あらゆるコンピュータビジョンタスクに適用できます。
このテーマをより深く掘り下げ、SSLが物体検出に使われている様子を見るには、こちらのチュートリアルを参照してください。