[改良版]KerasでVAT(Virtual Adversarial Training)を使ってMNISTをやってみる

テクノロジー

遅まきながら、VAT(Virtual Adversarial Training)という学習方法を知ったのですが、Kerasでの実装が見つからなかったので実装してみました。

はじめに

VATは簡単にいうと、「通常の入力X→出力Y」と「なるべく結果が異なるように入力Xに微小なノイズdを入力に加えた入力(X+d)→出力Y'」から「KL-Divergence(Y, Y')」を損失関数に余分に加えて学習をする手法です。

これだけだと何言ってるかわからないと思うので、詳しくは元の論文か、この方の解説をご覧になると良いかと思います。

VATは学習における位置づけとしては「正則化」に近いという話で、DropoutやNoiseを加える代わりになる可能性があります。Dropoutとかのパラメータを調整するのも面倒なので、VATで代用できると嬉しい気がします。

Kerasだとコスト関数や正則化関数に入力Xを使うようにするのが少し厄介なのですが、そこさえなんとかなれば、ChainerやTheanoでの実装があるので移植すればOKです。

先日、KerasでVAT(Virtual Adversarial Training)を使ってMNISTをやってみるを投稿したのですが、もう少しマシっぽい実装ができたので共有します。

Version

実装


前回からの違いは、

というところです。
mnist_with_vat_model.py
# coding: utf8

"""

* VAT: https://arxiv.org/abs/1507.00677
# 参考にしたCode

Original: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py

VAT: https://github.com/musyoku/vat/blob/master/vat.py


results example

---------------


finish: use_dropout=False, use_vat=False: score=0.215942835068, accuracy=0.9872

finish: use_dropout=True, use_vat=False: score=0.261140023788, accuracy=0.9845

finish: use_dropout=False, use_vat=True: score=0.240192672965, accuracy=0.9894

finish: use_dropout=True, use_vat=True: score=0.210011005498, accuracy=0.9891

"""

import numpy as np

from functools import reduce

from keras.engine.topology import Input, Container, to_list

from keras.engine.training import Model


np.random.seed(1337) # for reproducibility


from keras.datasets import mnist

from keras.layers import Dense, Dropout, Activation, Flatten

from keras.layers import Convolution2D, MaxPooling2D

from keras.utils import np_utils

from keras import backend as K


SAMPLE_SIZE = 0


batch_size = 128

nb_classes = 10

nb_epoch = 12


# input image dimensions

img_rows, img_cols = 28, 28

# number of convolutional filters to use

nb_filters = 32

# size of pooling area for max pooling

pool_size = (2, 2)

# convolution kernel size

kernel_size = (3, 3)


def main(data, use_dropout, use_vat):

np.random.seed(1337) # for reproducibility


# the data, shuffled and split between train and test sets

(X_train, y_train), (X_test, y_test) = data


if K.image_dim_ordering() == 'th':

X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)

X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)

input_shape = (1, img_rows, img_cols)

else:

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)

X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)

input_shape = (img_rows, img_cols, 1)


X_train = X_train.astype('float32')

X_test = X_test.astype('float32')

X_train /= 255.

X_test /= 255.


# convert class vectors to binary class matrices

y_train = np_utils.to_categorical(y_train, nb_classes)

y_test = np_utils.to_categorical(y_test, nb_classes)


if SAMPLE_SIZE:

X_train = X_train[:SAMPLE_SIZE]

y_train = y_train[:SAMPLE_SIZE]

X_test = X_test[:SAMPLE_SIZE]

y_test = y_test[:SAMPLE_SIZE]


print("start: use_dropout=%s, use_vat=%s" % (use_dropout, use_vat))

my_model = MyModel(input_shape, use_dropout, use_vat).build()

my_model.training(X_train, y_train, X_test, y_test)


score = my_model.model.evaluate(X_test, y_test, verbose=0)

print("finish: use_dropout=%s, use_vat=%s: score=%s, accuracy=%s" % (use_dropout, use_vat, score[0], score[1]))


class MyModel:

model = None


def __init__(self, input_shape, use_dropout=True, use_vat=True):

self.input_shape = input_shape

self.use_dropout = use_dropout

self.use_vat = use_vat


def build(self):

input_layer = Input(self.input_shape)

output_layer = self.core_data_flow(input_layer)

if self.use_vat:

self.model = VATModel(input_layer, output_layer).setup_vat_loss()

else:

self.model = Model(input_layer, output_layer)

return self


def core_data_flow(self, input_layer):

x = Convolution2D(nb_filters, kernel_size[0], kernel_size[1], border_mode='valid')(input_layer)

x = Activation('relu')(x)

x = Convolution2D(nb_filters, kernel_size[0], kernel_size[1])(x)

x = Activation('relu')(x)

x = MaxPooling2D(pool_size=pool_size)(x)

if self.use_dropout:

x = Dropout(0.25)(x)


x = Flatten()(x)

x = Dense(128, activation="relu")(x)

if self.use_dropout:

x = Dropout(0.5)(x)

x = Dense(nb_classes, activation='softmax')(x)

return x


def training(self, X_train, y_train, X_test, y_test):

self.model.compile(loss=K.categorical_crossentropy, optimizer='adadelta', metrics=['accuracy'])

np.random.seed(1337) # for reproducibility

self.model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,

verbose=1, validation_data=(X_test, y_test))


class VATModel(Model):

_vat_loss = None


def setup_vat_loss(self, eps=1, xi=10, ip=1):

self._vat_loss = self.vat_loss(eps, xi, ip)

return self


@property

def losses(self):

losses = super(self.__class__, self).losses

if self._vat_loss:

losses += [self._vat_loss]

return losses


def vat_loss(self, eps, xi, ip):

normal_outputs = [K.stop_gradient(x) for x in to_list(self.outputs)]

d_list = [K.random_normal(x.shape) for x in self.inputs]


for _ in range(ip):

new_inputs = [x + self.normalize_vector(d)*xi for (x, d) in zip(self.inputs, d_list)]

new_outputs = to_list(self.call(new_inputs))

klds = [K.sum(self.kld(normal, new)) for normal, new in zip(normal_outputs, new_outputs)]

kld = reduce(lambda t, x: t+x, klds, 0)

d_list = [K.stop_gradient(d) for d in K.gradients(kld, d_list)]


new_inputs = [x + self.normalize_vector(d) * eps for (x, d) in zip(self.inputs, d_list)]

y_perturbations = to_list(self.call(new_inputs))

klds = [K.mean(self.kld(normal, new)) for normal, new in zip(normal_outputs, y_perturbations)]

kld = reduce(lambda t, x: t + x, klds, 0)

return kld


@staticmethod

def normalize_vector(x):

z = K.sum(K.batch_flatten(K.square(x)), axis=1)

while K.ndim(z) < K.ndim(x):

z = K.expand_dims(z, dim=-1)

return x / (K.sqrt(z) + K.epsilon())


@staticmethod

def kld(p, q):

v = p * (K.log(p + K.epsilon()) - K.log(q + K.epsilon()))

return K.sum(K.batch_flatten(v), axis=1, keepdims=True)


data = mnist.load_data()

main(data, use_dropout=False, use_vat=False)

main(data, use_dropout=True, use_vat=False)

main(data, use_dropout=False, use_vat=True)

main(data, use_dropout=True, use_vat=True)

実験結果


前回と同じように実験してみました。

Dropout VAT Accuracy 1 epochの時間
使わない 使わない 98.72% 8秒
使う 使わない 98.45% 8秒
使わない 使う 98.94% 18秒
使う 使う 98.91% 18秒
だいたい同じような結果になりました。

さいごに


どっちみちPlaceholderに対して計算するんだから、これでも良いはずだと思って色々試行錯誤していたら上手くいった気がします。なかなかTensorの流れをちゃんとイメージするのが難しいですね。

本当はContainerにこの機能を付けようとしたんですが(教師なしでも使えるのだから)、現在のKerasの実装だとContainerが余分なLossをModelのtotal_lossに足し込む仕組みがわからず断念。Layerだと複数入れ込めますが、遠くのfunction(input) -> outputを別途渡してあげないとVATは計算できないのであまり嬉しくない。まあ、前回よりマシになったので良しにします。

まずはお気軽にお問い合わせください

「Sprocket」の費用や導入スケジュール、また、御社の顧客体験の向上やコンバージョンの最適化、Web接客ツールの比較検討においてご不明な点がございましたら、お気軽にお問い合わせください。(無料)

03-6303-4123

受付時間:平日10時~12時/13時~17時