こんにちは.エクサウィザーズでインターンをしている川畑です.
視覚によるコミュニケーションというのは人々が相手に何らかのアイデアを伝える際に鍵となります.私たちは小さい頃から物体を描く力を養ってきており,時には感情までもたった複数の線で表現することも可能です.こうした単純な絵というのは,身の回りのものを写真のように捉え忠実に再現したものではなく,どのようにして人間が物体の特徴を認識しそれらを再現するか,ということを教えてくれます.
そこで今回はSketch-RNNと呼ばれるRecurrent Neural Networkモデルでのスケッチの自動生成に取り組んでみました.
このモデルは人間がするのと同じように抽象的な概念を一般化し,スケッチを生成することを目的としたものです.このモデルに関しては今の所具体的なアプリケーションが存在するというわけではなく,機械学習がどのようにクリエイティブな分野で活用できるか,という一例を提案したものになります.ソースコードはこちらからダウンロードできます.
[1] Google AI Blog: Teaching Machines to Draw
データセット
今回は論文でも用いられていたQuick, Draw!のデータセットを使用しました.これは20秒以内であるお題の絵を描くゲームで,データセットとしてネコやブタ,バスなど数百個のクラスのデータを公開しています.各クラスのデータは,訓練,検証,テストとしてそれぞれ70000,2500,2500のデータに分かれています.
Sketch-RNNモデルを学習させるにあたり,データのフォーマットとしては各点のデータが5つの要素からなるベクトルを使用しました.最初の二つの要素は,前の点からのの変位,残りの三つの要素はone-hotベクトルとなっています.一つ目のペン状態はペンが現在紙に接していて,次の点まで線が描かれることを示しており,二つ目のペン状態はそこでペンが紙から離れ,次の点まで線は引かれないことを示しており,三つ目のペン状態はスケッチが終了することを示しています.
Sketch-RNNモデル
それではモデルについて解説していきたいと思います.
[2] A Neural Representation of Sketch Drawings
モデルの大枠はSequence-to-Sequence Variational Autoencoder (VAE) でできています.まず,エンコーダーであるbidirectional RNN (今回は単純なLSTMを使用)にスケッチのシーケンスを入力し,出力として潜在ベクトルを得ます.具体的には,双方向のRNNから得られた隠れ状態を連結し,全結合層によって連結された隠れ状態からおよびを得ます.に関しては,負にならないようにの操作を加えます.そしてさらにガウス分布と組み合わせることで最終的に潜在ベクトルを計算することができます.式で書くと以下のようになります.
潜在ベクトルが得られたら,次はデコーダーです.隠れ状態の初期値には潜在ベクトルに関数を掛けたものを用います.
各ステップでのデコーダーの入力には前ステップでのストロークと潜在ベクトルを連結したものを与え,にはスケッチの開始を意味するを与えるようにします.各ステップでの出力は次のデータ点の確率分布に関するパラメータを返します.Sketch-RNNでは個の正規分布からなるGaussian mixture model (GMM)によりを,カテゴリカル分布により真値であるをモデル化します.
上の式においては2変量正規分布を表しており,5つのパラメータで構成されます.は平均,は標準偏差,はの相関係数です.また,は長さのベクトルであり,Gaussian mixture modelの各正規分布の重みに相当します.ですので,出力のサイズは3つのロジットも含めて合計でとなります.
ここで,出力ベクトルは次のように分解されます.
標準偏差,相関係数に関してはそれぞれ負にならないように,また-1から1の値を取るようにとによる操作を施します.
カテゴリカル分布に関してはSoftmax関数を適応させ,全ての値が0〜1に収まるようにします.
損失関数
損失関数はVAEと同じでReconstruction Loss, とKullback-Leibler (KL) Divergence Loss, の二つの項からなります.
- Reconstruction Loss,
Reconstruction Loss項は訓練データを説明する確率分布の対数尤度を表しており,これを最大化するように学習します.そして,Reconstruction Loss, はさらに座標に関する項とペン状態に関する項からなっており,それぞれ
以下のように書くことができます.
- Kullback-Leibler (KL) Divergence Loss,
KL Divergence Loss項は潜在ベクトルが標準正規分布からどれだけ離れているかを表しており,これを最小化するように学習することになります.
実際に最適化する損失関数にはとを重み付けして足し合わせたものを用います.
それぞれの損失項にはトレードオフの関係があり,の時にはモデルは純粋なオートエンコーダーに近づき,より良いReconstruction Lossを得ることができる一方で,潜在空間における事前分布の強化を犠牲にすることになります.
サンプリング
モデルを学習させた後はいよいよスケッチの生成です.サンプリング過程では各ステップごとにGMMとカテゴリカル分布のパラメータを得,そのステップでの出力を得ます.訓練過程とは異なり,サンプリング過程では出力を次のステップの入力とし,このサンプリングステップをとなるかステップ数がとなるまで繰り返していきます.最終的に出力を得る際にはdeterministicに確率密度関数で最も確率の高い点を選ぶのではなく,下記のようにtemperatureパラメータを導入することで出力のランダムさを調節できるようにしています.は0〜1の値を取り,の時にはモデルはdeterministicになります.
コード
モデルに関する部分のコードを示します.Githubに論文著者によるオリジナルのコードもありますが,オリジナルではtensorflowで書かれていたものをkerasで書き換えました.全コードはこちらからダウンロードできます.
def get_mixture_coef(output, n_out):
"""Returns the tf slices containing mdn dist params."""
z = output
z = tf.reshape(z, [-1, n_out])
z_pen_logits = z[:, 0:3]
z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1)
z_pi = tf.nn.softmax(z_pi)
z_pen = tf.nn.softmax(z_pen_logits)
z_sigma1 = K.exp(z_sigma1)
z_sigma2 = K.exp(z_sigma2)
z_corr = tf.tanh(z_corr)
r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
return r
class SketchRNN():
"""SketchRNN model definition."""
def __init__(self, hps):
self.hps = hps
self.build_model(hps)
def build_model(self, hps):
encoder_inputs = Input(shape=(hps.max_seq_len, 5), name='encoder_input')
encoder_lstm = LSTM(hps.enc_rnn_size,
use_bias=True,
recurrent_initializer='orthogonal',
bias_initializer='zeros',
recurrent_dropout=1.0-hps.recurrent_dropout_prob,
return_sequences=True,
return_state=True)
bidirectional = Bidirectional(encoder_lstm)
(unused_outputs,
last_h_fw, unused_c_fw,
last_h_bw, unused_c_bw) = bidirectional(encoder_inputs)
last_h = concatenate([last_h_fw, last_h_bw], 1)
normal_init = RandomNormal(stddev=0.001)
self.z_mean = Dense(hps.z_size,
activation='linear',
use_bias=True,
kernel_initializer=normal_init,
bias_initializer='zeros')(last_h)
self.z_presig = Dense(hps.z_size,
activation='linear',
use_bias=True,
kernel_initializer=normal_init,
bias_initializer='zeros')(last_h)
def sampling(args):
z_mean, z_presig = args
self.sigma = K.exp(0.5 * z_presig)
batch = K.shape(z_mean)[0]
dim = K.int_shape(z_mean)[1]
epsilon = K.random_normal((batch, dim), 0.0, 1.0)
batch_z = z_mean + self.sigma * epsilon
return batch_z
self.batch_z = Lambda(sampling,
output_shape=(hps.z_size,))([self.z_mean, self.z_presig])
self.encoder = Model(
encoder_inputs,
[self.z_mean, self.z_presig, self.batch_z], name='encoder')
self.n_out = (3 + hps.num_mixture * 6)
decoder_inputs = Input(shape=(hps.max_seq_len, 5), name='decoder_input')
overlay_x = RepeatVector(hps.max_seq_len)(self.batch_z)
actual_input_x = concatenate([decoder_inputs, overlay_x], 2)
self.initial_state_layer = Dense(hps.dec_rnn_size * 2,
activation='tanh',
use_bias=True,
kernel_initializer=normal_init)
initial_state = self.initial_state_layer(self.batch_z)
initial_h, initial_c = tf.split(initial_state, 2, 1)
self.decoder_lstm = LSTM(hps.dec_rnn_size,
use_bias=True,
recurrent_initializer='orthogonal',
bias_initializer='zeros',
recurrent_dropout=1.0-hps.recurrent_dropout_prob,
return_sequences=True,
return_state=True
)
output, last_h, last_c = self.decoder_lstm(
actual_input_x, initial_state=[initial_h, initial_c])
self.output_layer = Dense(self.n_out, activation='linear', use_bias=True)
output = self.output_layer(output)
last_state = [last_h, last_c]
self.final_state = last_state
self.sketch_rnn_model = Model(
[encoder_inputs, decoder_inputs],
output,
name='sketch_rnn')
def vae_loss(self, inputs, outputs):
kl_loss = 1 + self.z_presig - K.square(self.z_mean) - K.exp(self.z_presig)
self.kl_loss = -0.5 * K.mean(K.sum(kl_loss, axis=-1))
self.kl_loss = K.maximum(self.kl_loss, K.constant(self.hps.kl_tolerance))
def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
"""Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
norm1 = subtract([x1, mu1])
norm2 = subtract([x2, mu2])
s1s2 = multiply([s1, s2])
z = (K.square(tf.divide(norm1, s1)) + K.square(tf.divide(norm2, s2)) -
2 * tf.divide(multiply([rho, multiply([norm1, norm2])]), s1s2))
neg_rho = 1 - K.square(rho)
result = K.exp(tf.divide(-z, 2 * neg_rho))
denom = 2 * np.pi * multiply([s1s2, K.sqrt(neg_rho)])
result = tf.divide(result, denom)
return result
def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
z_pen_logits, x1_data, x2_data, pen_data):
"""Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2,
z_corr)
epsilon = 1e-6
result1 = multiply([result0, z_pi])
result1 = K.sum(result1, 1, keepdims=True)
result1 = -K.log(result1 + epsilon)
fs = 1.0 - pen_data[:, 2]
fs = tf.reshape(fs, [-1, 1])
result1 = multiply([result1, fs])
result2 = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=pen_data, logits=z_pen_logits)
result2 = tf.reshape(result2, [-1, 1])
result2 = multiply([result2, fs])
result = result1 + result2
return result
target = tf.reshape(inputs, [-1, 5])
[x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1)
pen_data = concatenate([eos_data, eoc_data, cont_data], 1)
out = get_mixture_coef(outputs, self.n_out)
[o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out
lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
o_pen_logits, x1_data, x2_data, pen_data)
self.r_loss = tf.reduce_mean(lossfunc)
kl_weight = self.hps.kl_weight_start
self.loss = self.r_loss + self.kl_loss * kl_weight
return self.loss
def model_compile(self, model):
adam = Adam(lr=self.hps.learning_rate, clipvalue=self.hps.grad_clip)
model.compile(loss=self.vae_loss, optimizer=adam)
以下は学習後のモデルからスケッチを生成するサンプリングのコードです.
def sample(model, hps, weights, seq_len=250, temperature=1.0,
greedy_mode=False, z=None):
"""Samples a sequence from a pre-trained model."""
def adjust_temp(pi_pdf, temp):
pi_pdf = np.log(pi_pdf) / temp
pi_pdf -= pi_pdf.max()
pi_pdf = np.exp(pi_pdf)
pi_pdf /=pi_pdf.sum()
return pi_pdf
def get_pi_idx(x, pdf, temp=1.0, greedy=False):
"""Samples from a pdf, optionally greedily."""
if greedy:
return np.argmax(pdf)
pdf = adjust_temp(np.copy(pdf), temp)
accumulate = 0
for i in range(0, pdf.size):
accumulate += pdf[i]
if accumulate >= x:
return i
tf.logging.info('Error with smpling ensemble.')
return -1
def sample_gaussian_2d(mu1, mu2, s1, s2, rho, temp=1.0, greedy=False):
if greedy:
return mu1, mu2
mean = [mu1, mu2]
s1 *= temp * temp
s2 *= temp * temp
cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
x = np.random.multivariate_normal(mean, cov, 1)
return x[0][0], x[0][1]
model.sketch_rnn_model.load_weights(weights)
prev_x = np.zeros((1, 1, 5), dtype=np.float32)
prev_x[0, 0, 2] = 1
if z is None:
z = np.random.randn(1, hps.z_size)
batch_z = Input(shape=(hps.z_size,))
initial_state = model.initial_state_layer(batch_z)
decoder_input = Input(shape=(1, 5))
overlay_x = RepeatVector(1)(batch_z)
actual_input_x = concatenate([decoder_input, overlay_x], 2)
decoder_h_input = Input(shape=(hps.dec_rnn_size, ))
decoder_c_input = Input(shape=(hps.dec_rnn_size, ))
output, last_h, last_c = model.decoder_lstm(
actual_input_x,
initial_state=[decoder_h_input, decoder_c_input])
output = model.output_layer(output)
decoder_initial_model = Model(batch_z, initial_state)
decoder_model = Model([decoder_input, batch_z,
decoder_h_input, decoder_c_input],
[output, last_h, last_c])
prev_state = decoder_initial_model.predict(z)
prev_h, prev_c = np.split(prev_state, 2, 1)
strokes = np.zeros((seq_len, 5), dtype=np.float32)
greedy = False
temp = 1.0
for i in range(seq_len):
decoder_output, next_h, next_c = decoder_model.predict(
[prev_x, z, prev_h, prev_c])
out = sketch_rnn_model.get_mixture_coef(decoder_output, model.n_out)
[o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out
o_pi = K.eval(o_pi)
o_mu1 = K.eval(o_mu1)
o_mu2 = K.eval(o_mu2)
o_sigma1 = K.eval(o_sigma1)
o_sigma2 = K.eval(o_sigma2)
o_corr = K.eval(o_corr)
o_pen = K.eval(o_pen)
if i < 0:
greedy = False
temp = 1.0
else:
greedy = greedy_mode
temp = temperature
idx = get_pi_idx(random.random(), o_pi[0], temp, greedy)
idx_eos = get_pi_idx(random.random(), o_pen[0], temp, greedy)
eos=[0, 0, 0]
eos[idx_eos] = 1
next_x1, next_x2 = sample_gaussian_2d(o_mu1[0][idx], o_mu2[0][idx],
o_sigma1[0][idx], o_sigma2[0][idx],
o_corr[0][idx], np.sqrt(temp), greedy)
strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]
prev_x = np.zeros((1, 1, 5), dtype=np.float32)
prev_x[0][0] = np.array(
[next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
prev_h, prev_c = next_h, next_c
K.clear_session()
return strokes
結果
今回は時間の都合上フクロウのデータセットでのみモデルの学習を行いました.
それではまず入力に用いるスケッチをテストセットからランダムに選んできて,どのようなスケッチか見てみましょう.ちなみにこれは人間がフクロウを描いたものです.
正直フクロウっぽくないですが一応生き物っぽいので良しとします.
次にこのスケッチからエンコーダーによって潜在ベクトルを得ます.
そして最後にデコーダーによってスケッチを生成します.
なかなかフクロウっぽいのではないでしょうか?
それでは次に様々なtemperatureパラメータを用いた時にスケッチがどのように変化していくか見ていきましょう.
右にいくほどtemperatureの値は大きくなります.つまり,よりランダムになっていきます.temperatureが0.1の時はどちらかと言うとペンギンぽいですが,temperatureが0.3と0.5の時はかなりフクロウっぽいスケッチになっています.
かなり荒削りな部分もありますが,スケッチとして認識できるレベルまでしっかり学習ができていることがわかります.ただ,一つ気になったこととしては入力画像にかかわらず常にペンギンのようなスケッチを生成してしまっていたことです.ここの例でも示している通り,入力をかなり無視してペンギンのようなスケッチを生成しています.論文著者のコードではRNNセルにLayer Normalization付きのLSTMやHyperLSTMを用いることができるようになっており,またKL Lossのアニーリングも行なっていたのですが今回の実装ではそれらは含まれていなかったためこのような結果になったのではないかと考えています.入力画像に関係なく毎回同じようなスケッチを生成することに関しては,特にKL Lossのアニーリングを行なっていないことで,Reconstruction Lossに比べてKL Lossにばかり重点が置かれたことが原因だと考えています.
参照
最後に
尚,エクサウィザーズは20卒向けのAIエンジニアのポジションで,内定直結型インターンを東京・京都オフィスで募集しています.ご興味を持たれた方はぜひご応募ください.
ご応募はこちらから