エクサウィザーズ Engineer Blog

株式会社エクサウィザーズのエンジニアチームブログ

画像の内容をAIが文章で答えるデモ

f:id:kentaro-suto:20190124195850j:plain

こんにちは。エクサウィザーズAIエンジニアの須藤です。

エクサウィザーズ が提供しているAIプラットフォームexaBaseに、「画像の内容をAIが答える」という事例があります。 画像解析とテキスト生成という、ディープラーニングでも特に成功している分野の組み合わせであり、いかにも人工知能らしい応用例となっています。

今回、これをWebブラウザ環境に移植できたので報告します。 いつもよりモデルが大規模なため動作環境が限られるのですが、特別な設定などは必要ありませんので、気軽に試していただきたいと思います。

作り方

このアプリケーションは、画像解析を担当する学習済みInceptionモデル*1と、テキスト生成を担当するLSTMモデルの2つで構成されています。 処理の流れは以下の通りです。

  1. 必要なモデルを読み込む。
  2. 画像を取得する。
  3. 画像をデータに変換する。
  4. 画像データをInceptionモデルに与え、特徴量を得る。
  5. LSTMモデルに、特徴量と開始コードを与え、次の単語の確率を推論する。
  6. 最も確率の大きい単語を選び、画面に書き出す。
  7. LSTMモデルに、特徴量と選んだ単語の番号を与え、次の単語の確率を推論する。
  8. 終端コードが選ばれるまで4と5を繰り返す。

元のコードはPython + Kerasで書かれていました。 これをJavaScript + Tensorflow.jsに書き直します。 読み込みや推論など、元々フレームワークを使用していた部分は、JavaScriptでも同等の機能で置き換えることができます。 その他のPythonで直接記述されていた部分は、JavaScriptに翻訳する必要があります。

また、オリジナルのモデルでMergeレイヤーが使われていましたが、これはTensorflow.jsがまだ扱えないクラスです。 Concatenateレイヤーを使ってモデルを書き直す必要がありました。

動作環境

モデルの規模が大きくメモリを大量に消費するためか、動作できる環境が限られます。 参考までに私の環境での検証結果を示します。 ハード環境によっては異なる結果になるかもしれません。

プラットフォーム ブラウザ バージョン 対応状況
macOS Safari 12 × 実行途中で再読み込み
macOS Chrome 71
macOS Firefox 64
macOS Opera 57 × 読み込みでエラー
Windows IE 11 × JavaScriptが未対応
Windows Edge 42 x 読み込み中に再読み込み
Windows Chrome 71 × 読み込みでエラー
Windows Firefox 64
Windows Opera 58 × 読み込みでエラー

使い方

  1. このページにアクセスしてください。 https://base.exawizards.com/view/modelDetail?id=41

  2. ページの中ほどのリンクをクリックして、しばらくお待ちください。 f:id:kentaro-suto:20190124182906p:plain

  3. 2つのモデルを別々に読み込んでいます。このように表示されたら折り返し地点です。 f:id:kentaro-suto:20190124185412p:plain

  4. UIが表示されたら、サンプル画像を選択するか、ローカルファイルを読み込ませてください。 f:id:kentaro-suto:20190124185754p:plain

5.画像を説明するテキストが表示されます。 f:id:kentaro-suto:20190124185923p:plain

結果の例

手持ちの画像で検証しました。

画像 テキスト ひとこと
f:id:kentaro-suto:20190124190550j:plain two white and white cat standing on top on the roof 白茶でなくて白白なのが惜しいですが、ほぼ合っています。
f:id:kentaro-suto:20190125133238j:plain cows are standing on the grass near the grass たまに同じ言葉を繰り返します。
f:id:kentaro-suto:20190124191659j:plain the snow covered street is covered with snow 雪が積もりまくっていますね。
f:id:kentaro-suto:20190124192218j:plain the cake is on the plate with the other fruit ばななが入ったお菓子、正解です。
f:id:kentaro-suto:20190124190820j:plain an open box with some food on it 国籍も不明だし、なるほどなんらかの食べ物としか言えませんね。
f:id:kentaro-suto:20190124192613j:plain there is some food that is on the table とにかく食べ物なことはわかるようです。
f:id:kentaro-suto:20190124192917j:plain the yellow flower is growing on the side オレンジの花が嫌いなんでしょうか。
f:id:kentaro-suto:20190124193151j:plain cars parked on the side walk near trees うまくいっている例です。
f:id:kentaro-suto:20190124193727j:plain the seagull is standing on the boardwalk near the stairs 階段も遊歩道も実は近くにあります。「ありそうな雰囲気」で判断しているのかもしれません。
f:id:kentaro-suto:20190124193331j:plain an image shows an image on the page to describe 説明できていない感じが伝わります。

まとめ

ブラウザ上で動く、画像の内容をAIが答えるデモを作りました。

尚、エクサウィザーズは優秀なエンジニア、社会課題を一緒に解決してくれる魔法使い”ウィザーズ”を募集しています。ご興味を持たれた方はぜひご応募ください。 採用情報|株式会社エクサウィザーズ

ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

*1:Inceptionモデルの重みは Apacheライセンスの下で公開されています

ブラウザで動く落書き判定モデルの作り方

はじめに

こんにちは。エクサウィザーズAIエンジニアの須藤です。

昨年の弊社忘年会の出し物として、落書き判定モデルを作りました。 お題に合わせて絵を書いて、AIにそれと判定させたら勝ちになるゲームです。 思いのほかちゃんと判定してくれて、ほっとしました。 取り立てて目新しさはありませんが、皆さんにも遊んでいただきたいと思い、ここで紹介します。

exaBaseのモデル詳細ページで実際に遊べます。 ブラウザだけで動作しますので、お気軽にお試しください。

データセット

Googleが提供しているQuick, Draw!というゲームのデータを使用します。 これは、お題に合う絵を描いて、AIに判定してもらうというゲームです。 制限時間は20秒で、AIが候補に挙げた時点でクリアとなります。 世界中のプレーヤーによって描かれた絵が、学習用データベースとして無償で提供されています

データセットには

  • ストローク
  • ストロークを画像(28x28x1)にしたもの
  • カテゴリー
  • 国情報
  • 時刻情報

などが含まれています。今回は画像とカテゴリーだけを用いて学習を行います。 ストロークを利用する学習モデルについては、以前の川畑さんの記事をご覧ください。

画像データはカテゴリー別にNumPy形式のバイナリファイルになっていて、Pythonで

np.load('cat.npy', mmap_mode='r')

とすると読み込むことができます。 1500万人が遊んだというだけあって、約5千万サンプル、約40GBの巨大なデータです。

カテゴリーは以下の通りで、全部で347あります。

エッフェル塔 万里の長城 モナリザ 空母 飛行機 目覚まし時計 救急車 天使 動物の移動 蟻 金床 りんご 腕 アスパラガス 斧 バックパック バナナ 包帯 納屋 バット 野球 バスケット バスケットボール コウモリ バスタブ ビーチ くま あごひげ ベッド 蜂 ベルト ベンチ 自転車 双眼鏡 鳥 誕生日ケーキ ブラックベリー ブルーベリー 本 ブーメラン ボトルキャップ 蝶ネクタイ ブレスレット 脳 パン 橋 ブロッコリー ほうき バケツ ブルドーザー バス ブッシュ 蝶 サボテン ケーキ 電卓 カレンダー ラクダ カメラ 迷彩 キャンプファイヤー ろうそく 大砲 カヌー 車 人参 城 ネコ 天井ファン 携帯電話 チェロ 椅子 シャンデリア 教会 サークル クラリネット 時計 雲 コーヒーカップ コンパス コンピューター クッキー 冷却装置 ソファー 牛 カニ クレヨン クロコダイル 王冠 遊覧船 カップ ダイヤモンド 食器洗い機 飛び込み台 犬 イルカ ドーナツ ドア ドラゴン 化粧ダンス ドリル ドラム アヒル ダンベル 耳 肘 象 封筒 消しゴム 眼 めがね 面 扇風機 羽 柵 指 消火栓 暖炉 消防車 魚 フラミンゴ 懐中電灯 ビーチサンダル フロアランプ 花 空飛ぶ円盤 足 フォーク カエル フライパン 庭用ホース 庭 キリン あごひげ ゴルフクラブ ぶどう 草 ギター ハンバーガー ハンマー 手 ハープ 帽子 ヘッドホン ハリネズミ ヘリコプター ヘルメット 六角形 ホッケーパック ホッケースティック 馬 病院 熱気球 ホットドッグ 温水浴槽 砂時計 観葉植物 家 ハリケーン アイスクリーム ジャケット 刑務所 カンガルー 鍵 キーボード 膝 ナイフ はしご ランタン ノートパソコン 葉 脚 電球 ライター 灯台 稲妻 線 ライオン 口紅 ロブスター ロリポップ メールボックス 地図 マーカー マッチ メガホン マーメイド マイクロフォン 電子レンジ 猿 月 蚊 バイク 山 マウス 口ひげ 口 マグカップ キノコ 爪 ネックレス 鼻 海洋 八角形 たこ 玉ねぎ オーブン ふくろう ペンキ缶 絵筆 ヤシの木 パンダ ズボン ペーパークリップ パラシュート オウム パスポート 落花生 梨 豆 鉛筆 ペンギン ピアノ ピックアップトラック 額縁 豚 枕 パイナップル ピザ ペンチ 警察車 池 プール アイスキャンデー はがき じゃがいも コンセント 財布 ウサギ アライグマ 無線 雨 虹 レーキ リモコン サイ ライフル 川 ジェットコースター ローラースケート ヨット サンドイッチ のこぎり サックス スクールバス はさみ サソリ ドライバー ウミガメ シーソー 鮫 羊 靴 ショーツ シャベル シンク スケートボード 頭蓋骨 超高層ビル 寝袋 笑顔 かたつむり ヘビ スノーケル スノーフレーク 雪だるま サッカーボール 靴下 快速艇 クモ スプーン スプレッドシート 四角 殴り書き リス 階段 星 ステーキ ステレオ 聴診器 縫い目 一時停止標識 コンロ イチゴ 街路灯 サヤインゲン 潜水艦 スーツケース 太陽 白鳥 セーター スイングセット 剣 注射器 Tシャツ テーブル ティーポット テディベア 電話 テレビ テニスラケット テント 虎 トースター つま先 トイレ 歯 歯ブラシ 歯磨き粉 竜巻 トラクター 信号機 列車 木 三角形 トロンボーン トラック トランペット 傘 下着 バン 花瓶 バイオリン 洗濯機 スイカ ウォータースライダー 鯨 ホイール 風車 ワインボトル ワイングラス 腕時計 ヨガ シマウマ ジグザグ

画像を可視化してみると、かなり雑で記号的な絵になっていることがわかります。

f:id:kentaro-suto:20190110180140j:plain f:id:kentaro-suto:20190110180208j:plain f:id:kentaro-suto:20190110180218j:plain f:id:kentaro-suto:20190121170519j:plain f:id:kentaro-suto:20190121170607j:plain f:id:kentaro-suto:20190121170631j:plain f:id:kentaro-suto:20190121170716j:plain f:id:kentaro-suto:20190121170745j:plain f:id:kentaro-suto:20190121170815j:plain f:id:kentaro-suto:20190121170851j:plain

モデル

畳み込みと全結合による、ごくシンプルなモデルです。 Python+Kerasで実装しました。

model = Sequential()
model.add(Reshape((28,28,1), input_shape=(784,)))
model.add(Conv2D(64, 5, padding='same')) #畳み込み
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(64, 5, padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D())
model.add(Conv2D(128, 5, padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(128, 5, padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D())
model.add(Conv2D(128, 5, padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(128, 5, padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Flatten()) #ここから全結合
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(345, activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

学習

データセットからランダムにサンプルを選んで、入力(画像データ)から出力(カテゴリーごとの確率)が得られるように、学習を行います。

学習量は

  • バッチサイズ=100
  • ステップ数=40000
  • エポック数=100
  • のべ=4億サンプル

でした。

インターフェイスと使い方

使い方を、実装詳細とともに解説します。 インターフェイスはHTML+JavaScript+Tensorflow.jsで作成しました。 Tensorflow.jsの基本的な使い方については以前の記事をご参照ください。

1.ページを開く

このページにアクセスします。 ページを開くとすぐに読み込みが始まります。 ダウンロードデータは全部で8.4MBあります。ネットワーク環境によってはお時間をいただくかもしれません。

2.絵を描く

ページ全体がキャンバスになっています。 ドラッグで線を描いてください。 f:id:kentaro-suto:20190111194202p:plain

判定時には線が描かれた領域だけがモデルに合わせてリサイズされるので、どこにどんな大きさで描いても大丈夫です。 間違えたら「リセット」で全消去できます。

リサイズで線の太さがバラつかないように、線を描画するとき、同時にストロークデータとしても保存しています。

3.判定させる

「送信」ボタンを押すと判定処理を開始します。

保存したストロークデータを用いて28×28の領域に再描画をします*1。 描いたピクセルを配列に変換し*2、モデルに入力します。

4.結果を見る

判定結果を表示します。

f:id:kentaro-suto:20190111195534p:plain

モデルの出力は各カテゴリーの確率として得られます。 その中で、最も大きなものに対応するカテゴリー名を表示します。 他のカテゴリーの確率は、詳細をクリックすると見られるテーブルに書き出します。

実行例

f:id:kentaro-suto:20190111200253p:plain f:id:kentaro-suto:20190111200309p:plain f:id:kentaro-suto:20190111200346p:plain f:id:kentaro-suto:20190115103611p:plain
これはドアで、 これが枕です。 四角と判定させるには、場違いな正確さが要求されます。 円については、そこまでシビアではありません。
f:id:kentaro-suto:20190111200204p:plain f:id:kentaro-suto:20190115103907p:plain f:id:kentaro-suto:20190111203526p:plain f:id:kentaro-suto:20190115101151p:plain
うっかりネコにしましまを描くと、高確率で虎になります。 学習データ全部、殴り書きのようなものなので、 リアルさにこだわると、却って判別されないことがあります。 極限まで図案化するのがいいようです。
f:id:kentaro-suto:20190115143611p:plain f:id:kentaro-suto:20190111202410p:plain f:id:kentaro-suto:20190115143713p:plain f:id:kentaro-suto:20190111201550p:plain
よくわからないなにかを描いても 必ず、347カテゴリーのどれかに判定します。 難しそうなカテゴリーも 意外に簡単な特徴で判定している場合があります。

まとめ

落書き判定モデルを作りました。 花とか月といった、誰でも同じようになる絵は高い精度で認識します。 期待通りに認識しない場合もありますが、間違い方に妙な納得感があったりします。 機械学習の面白くて不思議なところです。

尚、エクサウィザーズは優秀なエンジニア、社会課題を一緒に解決してくれる魔法使い”ウィザーズ”を募集しています。ご興味を持たれた方はぜひご応募ください。 採用情報|株式会社エクサウィザーズ

ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

*1:このとき学習データに合わせて線の太さを1.5ピクセルにするのが重要です。 1ピクセルではまともに判定しません。 この対策を思いつくまでに大分手こずりました。

*2:このとき背景が0になるように、明るさの反転も行います。

いつでも心電図が取れたとして? CES 2019 レポート 3日目

こんにちは。ロボットチーム、ソフトウェアエンジニアの鳥居です。

昨日に引き続き、CES 2019のレポートです。エンジニア視点で面白かったものを紹介します。

ウェアラブルデバイスによる医療データの収集

f:id:eitaro-torii-exwzd:20190111184142j:plain
バイオメトリックセンサが入った様々なウェアラブルデバイス (Valencell)

医療に関するセッションや医療・健康に関するプロダクトの展示も多数ありました。その中の1つで「ヒアラブル(hearable)」デバイスの話が面白かったので紹介します。

ヒアラブルデバイスはウェアラブルデバイスの1種で、音楽を聴くために身に付けるもののことを指します。ヒアラブルデバイスは肌に直接触れるので、センサーを取り付けることで健康状態に影響する様々な情報を取得することができます。このようなデバイスは年々増えており、また取得されるデータも検査のタイミングだけでなく身につけている限り連続的に取得できるようになると思われます。より質の良いデータを得ることで、医師や患者のコミュニケーションが円滑になったり、患者自身がより自分の身体を知ることが期待されます。一方、データの増加に伴って重要になるのは「必要な情報をいかに絞り込むか」という点ですが、これは今後発展していく分野と考えられています。

ヒアラブルデバイスが特徴的なのは、このようなウェアラブルデバイスの中でも多くの割合を占めていることと、他にも活用の場面があると考えられていることです。耳に埋め込むデバイスはその内側から「本人が喋った声」を抽出することができるので音声入力のデバイスとして使用したり、聞き分けたい音を動的に切り替えるような「Augmented hearing」などの方向性が考えられています。

写真は Valencell という会社がセンサを提供しているウェアラブルデバイスを集めた展示のものです。既に多くの種類のデバイスで自身の健康状態が把握できるようになっています。

オピオイド問題への解決策

f:id:eitaro-torii-exwzd:20190111183902j:plain
神経痛と対応する脊椎の位置を示す展示 (NANS)

アメリカでは現在、麻薬系鎮痛剤を多量に処方されてしまうことで依存症や過剰摂取による死亡事故が発生しており、この問題はオピオイド問題と呼ばれています。オピオイドは特に慢性の痛みに対して処方されることが多く、オピオイドを使わない方法として神経に電気的な刺激を与えることで慢性の痛みを軽減するアプローチが提案されています。写真は神経痛を抑えるために脊椎のどの位置を刺激すればよいかを表したモデルで、この位置に電気的な刺激を与えるデバイスを埋め込んで薬の代わりに電気刺激で痛みを和らげることができます。アメリカでのみ問題となっていることですが、技術で社会的な問題に挑戦する姿勢は素晴らしいと思います。

エクサウィザーズでは医療の問題を解決したいエンジニアを募集しています。ご興味を持たれた方はぜひご応募ください!採用情報|株式会社エクサウィザーズ

また、ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

5Gの可能性とは CES 2019 レポート 2日目

こんにちは。ロボットチーム、ソフトウェアエンジニアの鳥居です。

昨日に引き続き、CES 2019のレポートです。エンジニア視点で面白かったものを紹介します。

5Gとエッジコンピューティング

CESでは展示の他に、業界のリーダーが集まって特定のトピックについて話をするイベントが多数開催されています。5Gとエッジコンピューティングについてのセッションが面白そうだったので聞きに行ってみました。

興味深かった話をざっくり箇条書きで:

  • 5Gによる通信が一般で使われるようになると「どこで計算するか」の依存性が薄くなる
    • レイテンシが小さくなりスループットが上がるので、データの受け渡しがすぐにできるようになる
    • 例えば、現在の AR/VR用のヘッドマウントディスプレイは本体に計算処理を行うモジュールが含まれるのでサイズが大きめだが、このモジュールを外して小さくすることができる
  • 計算のトレンドは振り子のように揺れている
    • (1980年代)メインフレーム,集中型 -> (1990年代) PC, 分散型 -> (2000年代) クラウド, 集中型 -> (20XX年代) エッジコンピューティング, 分散型
    • 消費者が持っている計算リソースが年々増加しており、高速なlast 1 mileの通信の登場によってエッジコンピューティングのメリットが大きくなる。

IBMの量子コンピューター

昨日のキーノートで、https://newsroom.ibm.com/2019-01-08-IBM-Unveils-Worlds-First-Integrated-Quantum-Computing-System-for-Commercial-Use#assets_all:IBMが商用量子コンピュータについてのアナウンスがありました。IBMのブースに模型が展示されてますので行ってみました。

f:id:eitaro-torii-exwzd:20190110134721j:plain
IBM Q System One の模型

研究室の外で動く商用量子コンピュータは世界初ということで、次の世代のコンピューティングを始めようとする意気込みが伝わってきました。

エクサウィザーズでは新しい時代のエンジニアリングに挑戦したいエンジニアを募集しています。ご興味を持たれた方はぜひご応募ください!採用情報|株式会社エクサウィザーズ

また、ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

次はVPUの時代が来る? CES 2019 レポート 1日目

こんにちは。ロボットチーム、ソフトウェアエンジニアの鳥居です。

新しいプロダクトのアイデアを得るために、ロボットチーム3名でコンシューマー・エレクトロニクス・ショー(CES) というイベントに参加しています。アメリカのラスベガスで毎年開催されているイベントで、今年の一般公開は1/8〜1/11の日程となっています。

この記事では、エンジニア視点で面白かったものをピックアップして紹介していきたいと思います。

AIセキュリティカメラ、VPU

f:id:eitaro-torii-exwzd:20190109105311j:plain
CES 2019 Innovation Awardを受賞した SimCam

家庭用の監視カメラですが、顔識別やネットワークの機能が備わっているものです。画像の識別の機能はクラウドベースで提供されているものが多いですが、このプロダクトは小さなカメラの中にVisual Processing Unit (VPU)を搭載しているため、クラウドへのデータアップロードが不要となっています。クラウド利用のコストが不要になることやプライバシー保護の観点で優れた特徴です。Google Assistant や Amazon Alexaとも連携ができ、その場合でも画像をアップロードすることなく動作することができます。展示では「知らない人が玄関に近づいたらAlexaが声を掛ける」という使用例が出ていました。

面白いなと思ったのは、既存ハードウェアでは解決できてない問題を新しい種類のハードウェアを使って解決しに行っているところです。AIを使った計算というとGPUを使うものと思ってしまいますが、識別を主に行うプロダクトにはGPUは過剰なリソースとなり消費電力などの問題を抱えてしまいます。VPUはカメラから取得した画像を処理するのに特化した計算機で、画像向けの機械学習の計算に適した構造のハードウェアとのことです。カメラの中にVPUを積むアイデアがいくつもの問題を解決していてプロダクトとしての魅力を高めていることにとても面白いと思いました。

展示会場ではVPUというキーワードを他にも聞くことができました。画像処理AIの消費者向けプロダクトは、VPUがキーワードになるのかもしれません。

オムロンのパレタイジングと卓球

f:id:eitaro-torii-exwzd:20190109113313j:plain
卓球ロボット (オムロン)

f:id:eitaro-torii-exwzd:20190109112726j:plain
パレタイズ (オムロン)

ロボットの展示コーナーで一番最初に見えてくるのがオムロンの卓球ロボットでした。ロボットアームの先にラケットが付いていて、人がロボットと対戦すると対戦終了後に「ここをこうするともっと上手くなる」というチャートが表示されるという展示でした。その横で同じロボットを使ったパレタイジングのデモがあるのですが、パレタイジングがとても速い!目を引く卓球の展示と、単体だとちょっと地味なパレタイジングのデモを同じロボットでやっているのをみて、見せ方が面白さを感じました。

エクサウィザーズでは新しいプロダクトを作ることに興味のあるエンジニアを募集しています。ご興味を持たれた方はぜひご応募ください!採用情報|株式会社エクサウィザーズ

また、ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

ブラウザで動くAI類語地図

こんにちは。エクサウィザーズAIエンジニアの須藤です。

類語辞典って便利ですよね。 書いた文章がしっくりこないときに、ニュアンスの違う単語に置き換えたり、和語と漢語と外来語を入れ替えたりできます。

しかし、適当な表現が出てこなくて、もっと漠然と言葉を探したい時はどうしましょう。 辞書を繰り返し引いて、類語の類語の類語を見て回るのはちょっと面倒ですね。 かと言って項目あたりの類語数が増えたとしても、その中で探す手間が増えて、やはり使いづらそうです。

そこで、こういうものを作ってみました。 f:id:kentaro-suto:20181128154227p:plain https://base.exawizards.com/view/modelDetail?id=44

詳細を以下で解説します。

概要

類語を平面に分散して表示するプログラムです。 JavaScriptで書かれていますので、最近のブラウザがあれば動きます。 類語は見出し語に似ているものほど中央に、かつ類語同士でも似たものほど近くに表示されるようになっています。 つまり、よりしっくりくる単語の周りを優先して探せば、素早く目的の単語を見つけられることになります。 ...ということを期待して作りました。

単語の位置は、Word2vecという手法で得られた意味ベクトルから計算して求めています。 ここで意味ベクトル*1とは、単語の意味を数値化したものです。 関連の深い単語同士はベクトルの値も近くなります。 実際のテキストから、後述する機械学習によって得られます。

学習

  1. 前処理
    • テキストを形態素解析エンジンであるMeCabで分解します。
    • 単語ごとに登場頻度を調べ、高い順に番号をつけます。
    • テキストを単語番号の列に変換します。
  2. 学習
    • 単語数とベクトルの次元数を決めておきます。
    • 前後の単語が与えられたときに真ん中の単語がどれかを選択肢の中から予測する、以下の学習モデル*2を構築します。 f:id:kentaro-suto:20181122170033p:plain
    • 予測自体に有用性は特にありません。欲しいのは副産物である重みパラメータの値です。
    • 下記の設定でソースごとに学習を行いました。
ソース ウィキペディア 青空文庫 小説家になろう*3
MeCab辞書 NEologd デフォルト NEologd
テキスト量 約6億語 約5000万語 約2億5000万語
全ての単語の種類 182,669 108,676 227,685
学習した単語の種類 100,000 100,000 100,000
ベクトル次元 300 300 300
学習回数*4 5億 2.8億 6.3億
  • 後処理
    • 重みWは単語数×ベクトルの次元を持った行列です。これがそのまま各単語の意味ベクトルを表します。
    • 単語とベクトルの関係をオブジェクトリテラル形式に書き出し、可視化プログラムで使えるようにします。

可視化

以下の方針で単語の表示位置を決定します。

  • 見出し語を画面の中心に配置する。
  • 見出し語に似た単語ほど、中央近くに配置する。
  • 似ていない単語同士が、できるだけ離れるようにする。

具体的には以下のような計算を行います。

  • 全ての見出し語の意味ベクトルを調べ\{a_i\}とし、その平均aを求めます。これが画面の中心にあたります。
  • aに近いベクトルを30個選び\{b_i\}とし、その平均bも求めます。これらが分布している主な方向を以下で求めます。
  • a-bと直交し分散 \overline{(\xi\cdot(b_i-b))^2}を最大にする方向ベクトル\xiを求めます。これがX方向になります。
  • a-bおよび\xiと直交し分散 \overline{(\psi\cdot(b_i-b))^2}を最大にする別の方向ベクトル\psiを求めます。これがY方向になります。
  • 意味ベクトルがv_iになる単語の画面中心からの距離をr={\sigma}\left(\frac{||a|| ||v_i||}{a\cdot{v_i}}\right)^6とします。式の形および指数6に深い意味はありません。色々試した中で表示結果が見やすかったものを選びました。{\sigma}はスケールで、\{b_i\}に対応する単語が画面にちょうど収まるように調節します。
  • 画面中心からの方向ベクトルを(\hat x_i,\hat y_i)=\frac{(\xi,\psi)\cdot(v_i-b)}{||(\xi,\psi)\cdot(v_i-b)||}で求めます。
  • 最後に座標を、(x_i,y_i)=r (\hat x_i,\hat y_i)+{\mu}で求めます。{\mu}は画面の中心の座標です。

使い方

  1. ソースの選択
    • ポップアップボタンで学習ソースを選択してください。
    • 選択すると、データの読み込みが始まります。データファイルは10〜20MBあるので、読み込みには時間がかかります。
  2. 見出し語の入力
    • テキストフィールドに単語を入力してください。
    • 複数の単語を入力できます。区切り文字が無くても、自動的に知っている単語に分解します。スペースなどで明示的に分解もできます。
    • 自立語*5が対象です。動詞と形容詞は終止形、形容動詞は語幹にしてください*6
  3. 画面の操作
    • 画面をドラッグするとスクロールできます。
    • 右下のスライダーを操作すると、ズームできます。
    • 単語をダブルクリックすると、その単語を新たな見出し語として再描画します。
  4. その他
    • 学習ソースと見出し語はアドレスバーに逐次反映されます。URLをコピーして別のブラウザで開くと、状態が再現されます。

使用例

ソースの違い

学習ソースを変えると、異なる観点で類語を検索することができます。

ウィキペディア 青空文庫 小説家になろう
f:id:kentaro-suto:20181122141733p:plain f:id:kentaro-suto:20181122141757p:plain f:id:kentaro-suto:20181122141808p:plain
妙に偏った選別になりました。題名などで繰り返し使われると影響を受けやすいようです。 犬と猫は使われる文脈が似ているため、類語あつかいになってしまいます。 様々な猫の呼び方が出てきます。
タンク f:id:kentaro-suto:20181122143214p:plain f:id:kentaro-suto:20181122143223p:plain f:id:kentaro-suto:20181122143305p:plain
水槽関連、燃料関連に分かれました。 燃料関連、戦車関連です。 タンク職関連が大勢を占めます。
四天王 f:id:kentaro-suto:20181122150948p:plain f:id:kentaro-suto:20181122150956p:plain f:id:kentaro-suto:20181122151004p:plain
比喩的な使われ方のほか、元祖四天王もカバーします。 仏教関連のほか、称号の例として頼光四天王が出てきました。 魔王の部下として倒される存在です。

興味深い例

制作の過程で気がついた、興味深い例を紹介します。 学習結果は随時アップデートしているので、現在のデータで試してもこの通りにならないことをご了承ください。

東京 京都 ローマ
f:id:kentaro-suto:20181122153111p:plain f:id:kentaro-suto:20181122153459p:plain f:id:kentaro-suto:20181122153729p:plain
東京が付く言葉、東京都の下部組織、東京の地名などに分かれます。 京都府の地名、京都市の地名、観光地に分かれます。 イタリアの各都市、古代ローマの詳細、他の文明などが出てきます。
家康 正義 しかし
f:id:kentaro-suto:20181122153909p:plain f:id:kentaro-suto:20181122154041p:plain f:id:kentaro-suto:20181122154322p:plain
戦国時代の家康、江戸時代の家康、安土桃山時代の家康に別れました。 正しい行いという意味の言葉と、名前が正義の人の苗字が出てきます。 接続詞なども入っています。

最後に

Word2vecの二次元可視化によって、言葉探しが楽しく便利に行えるようになる可能性を示しました。 こちらですぐ試せます。

尚、エクサウィザーズは優秀なエンジニア、社会課題を一緒に解決してくれる魔法使い”ウィザーズ”を募集しています。ご興味を持たれた方はぜひご応募ください。 採用情報|株式会社エクサウィザーズ

ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

*1:埋め込みベクトルまたは分散表現とも言います

*2:Negative Samplingという技術の変形です

*3:「小説家になろう」は 株式会社ヒナプロジェクトの登録商標です

*4:バッチサイズ×ステップ数×エポック数

*5:助詞と助動詞以外

*6:例:長く→長い、のびろ→のびる、毛むくじゃらな→毛むくじゃら

Sketch-RNN でスケッチの自動生成(VAE + LSTM)

こんにちは.エクサウィザーズでインターンをしている川畑です.

視覚によるコミュニケーションというのは人々が相手に何らかのアイデアを伝える際に鍵となります.私たちは小さい頃から物体を描く力を養ってきており,時には感情までもたった複数の線で表現することも可能です.こうした単純な絵というのは,身の回りのものを写真のように捉え忠実に再現したものではなく,どのようにして人間が物体の特徴を認識しそれらを再現するか,ということを教えてくれます.

そこで今回はSketch-RNNと呼ばれるRecurrent Neural Networkモデルでのスケッチの自動生成に取り組んでみました. このモデルは人間がするのと同じように抽象的な概念を一般化し,スケッチを生成することを目的としたものです.このモデルに関しては今の所具体的なアプリケーションが存在するというわけではなく,機械学習がどのようにクリエイティブな分野で活用できるか,という一例を提案したものになります.ソースコードはこちらからダウンロードできます.

f:id:k_kawabata:20181027163327p:plain [1] Google AI Blog: Teaching Machines to Draw

データセット

今回は論文でも用いられていたQuick, Draw!のデータセットを使用しました.これは20秒以内であるお題の絵を描くゲームで,データセットとしてネコやブタ,バスなど数百個のクラスのデータを公開しています.各クラスのデータは,訓練,検証,テストとしてそれぞれ70000,2500,2500のデータに分かれています.

Sketch-RNNモデルを学習させるにあたり,データのフォーマットとしては各点のデータが5つの要素からなるベクトル \left(\Delta x, \Delta y, p_{1}, p_{2}, p_{3} \right)を使用しました.最初の二つの要素は,前の点からの x, yの変位,残りの三つの要素はone-hotベクトルとなっています.一つ目のペン状態 p_{1}はペンが現在紙に接していて,次の点まで線が描かれることを示しており,二つ目のペン状態 p_{2}はそこでペンが紙から離れ,次の点まで線は引かれないことを示しており,三つ目のペン状態 p_{3}はスケッチが終了することを示しています.

Sketch-RNNモデル

それではモデルについて解説していきたいと思います.

[2] A Neural Representation of Sketch Drawings

f:id:k_kawabata:20181027164205p:plain

モデルの大枠はSequence-to-Sequence Variational Autoencoder (VAE) でできています.まず,エンコーダーであるbidirectional RNN (今回は単純なLSTMを使用)にスケッチのシーケンスを入力し,出力として潜在ベクトル zを得ます.具体的には,双方向のRNNから得られた隠れ状態を連結し,全結合層によって連結された隠れ状態から \muおよび \hat{\sigma}を得ます. \hat{\sigma}に関しては,負にならないように \sigma = \exp \left ({\frac{\hat{\sigma}}{2}} \right)の操作を加えます.そしてさらにガウス分布 \mathcal{N}(0, I)と組み合わせることで最終的に潜在ベクトルを計算することができます.式で書くと以下のようになります.

 \displaystyle \mu = W_{\mu}h + b_{\mu}, \;  \hat{\sigma} = W_{\sigma}h + b_{\sigma}, \; \sigma = \exp \left ({\frac{\hat{\sigma}}{2}} \right), \; z = \mu + \sigma\odot\mathcal{N}(0, I)

潜在ベクトルが得られたら,次はデコーダーです.隠れ状態の初期値には潜在ベクトルに tanh関数を掛けたものを用います.

 \displaystyle \left[h_{0}; c_{0} \right] = tanh(W_{z}z + b_{z})

各ステップでのデコーダーの入力には前ステップでのストローク S_{i-1}と潜在ベクトル zを連結したものを与え, S_{0}にはスケッチの開始を意味する (0, 0, 1, 0, 0)を与えるようにします.各ステップでの出力は次のデータ点の確率分布に関するパラメータを返します.Sketch-RNNでは M個の正規分布からなるGaussian mixture model (GMM)により \left(\Delta x, \Delta y \right)を,カテゴリカル分布 \left(q_{1}, q_{2}, q_{3} \right) \left(q_{1} + q_{2} + q_{3} = 1 \right)により真値である \left(p_{1}, p_{2}, p_{3} \right)をモデル化します.

 \displaystyle p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi _{j} \mathcal{N} \left(\Delta x, \Delta y \mid \mu_{x, j}, \mu_{y, j}, \sigma_{x, j}, \sigma_{y, j}, \rho_{xy, j} \right), where \sum_{j=1}^{M} \Pi_{j} = 1

上の式において \mathcal{N} \left(x, y \mid \mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy} \right)は2変量正規分布を表しており,5つのパラメータ \left(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy} \right)で構成されます. \mu_{x}, \mu_{y}は平均,  \sigma_{x}, \sigma_{y}は標準偏差, \rho_{xy} x, yの相関係数です.また, \Piは長さ Mのベクトルであり,Gaussian mixture modelの各正規分布の重みに相当します.ですので,出力のサイズは3つのロジット \left(q_{1}, q_{2}, q_{3} \right)も含めて合計で 5M + M + 3となります.

 \displaystyle x_{i} = \left[S_{i-1}; z \right], \left[h_{i}; c_{i} \right]

 y_{i} = W_{y}h_{i} + b_{y}, \; y_{i} \in \mathbb{R}^{6M+3}

ここで,出力ベクトル y_{i}は次のように分解されます.

 \displaystyle \left[ (\hat{\Pi}_{1} \, \mu_{x} \, \mu_{y} \, \hat{\sigma}_{x} \, \hat{\sigma}_{y} \,  \hat{\rho}_{xy})_{1} \ldots (\hat{\Pi}_{M} \, \mu_{x} \, \mu_{y} \, \hat{\sigma}_{x} \, \hat{\sigma}_{y} \,  \hat{\rho}_{xy})_{M} \; (\hat{q}_{1} \, \hat{q}_{2} \, \hat{q}_{3})  \right] = y_{i}

標準偏差,相関係数に関してはそれぞれ負にならないように,また-1から1の値を取るように exp tanhによる操作を施します.

 \displaystyle \sigma_{x} = \exp\left(\hat{\sigma}_{x} \right), \; \sigma_{y} = \exp\left(\hat{\sigma}_{y} \right), \; \rho_{xy} = tanh\left(\hat{\rho}_{xy}  \right)

カテゴリカル分布に関してはSoftmax関数を適応させ,全ての値が0〜1に収まるようにします.

 \displaystyle q_{k} = \frac{\exp(\hat{q}_{k})}{\sum_{j=1}^{3} \exp(\hat{q}_{j})}, \; k \in \{1, 2, 3\}, \; \Pi_{k} = \frac{\exp(\hat{\Pi}_{k})}{\sum_{j=1}^{M} \exp(\hat{\Pi}_{j})}, \; k \in \{1, \ldots , M\}

損失関数

損失関数はVAEと同じでReconstruction Loss,  L_{R}とKullback-Leibler (KL) Divergence Loss,  L_{KL}の二つの項からなります.

  • Reconstruction Loss,  L_{R}

Reconstruction Loss項は訓練データ Sを説明する確率分布の対数尤度を表しており,これを最大化するように学習します.そして,Reconstruction Loss,  L_{R}はさらに座標に関する項 L_{s}とペン状態に関する項 L_{p} からなっており,それぞれ 以下のように書くことができます.

 \displaystyle L_{s} = -\frac{1}{N_{max}} \sum_{i=1}^{N_{s}}log \left(\sum_{j=1}^{M}\Pi_{j, i} \mathcal{N} \left(\Delta x_{i}, \Delta y_{i} \mid \mu_{x, j, i}, \mu_{y, j, i}, \sigma_{x, j, i}, \sigma_{y, j, i}, \rho_{xy, j, i} \right)\right)

 \displaystyle L_{p} = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k, i} log\left(q_{k, i} \right)

 L_{R} = L_{s} + L_{p}

  • Kullback-Leibler (KL) Divergence Loss,  L_{KL}

KL Divergence Loss項は潜在ベクトル zが標準正規分布からどれだけ離れているかを表しており,これを最小化するように学習することになります.

 \displaystyle L_{KL} = -\frac{1}{2N_{z}} \left(1 + \hat{\sigma} - \mu^{2} - \exp(\hat{\sigma}) \right)

実際に最適化する損失関数には L_{R} L_{KL}を重み付けして足し合わせたものを用います.

 Loss = L_{R} + w_{KL}L_{KL}

それぞれの損失項にはトレードオフの関係があり, w_{KL} \rightarrow 0の時にはモデルは純粋なオートエンコーダーに近づき,より良いReconstruction Lossを得ることができる一方で,潜在空間における事前分布の強化を犠牲にすることになります.

サンプリング

モデルを学習させた後はいよいよスケッチの生成です.サンプリング過程では各ステップごとにGMMとカテゴリカル分布のパラメータを得,そのステップでの出力 S_{i}'を得ます.訓練過程とは異なり,サンプリング過程では出力 S_{i}'を次のステップの入力とし,このサンプリングステップを p_{3} =1となるかステップ数が N_{max}となるまで繰り返していきます.最終的に出力を得る際にはdeterministicに確率密度関数で最も確率の高い点を選ぶのではなく,下記のようにtemperatureパラメータ \tauを導入することで出力のランダムさを調節できるようにしています. \tauは0〜1の値を取り, \tau = 0の時にはモデルはdeterministicになります.

 \displaystyle \hat{q}_{k} \rightarrow \frac{\hat{q}_{k}}{\tau}, \; \hat{\Pi}_{k} \rightarrow \frac{\hat{\Pi}_{k}}{\tau}, \; \sigma_{x}^{2} \rightarrow \sigma_{x}^{2}\tau, \; \sigma_{y}^{2} \rightarrow \sigma_{y}^{2}\tau

コード

モデルに関する部分のコードを示します.Githubに論文著者によるオリジナルのコードもありますが,オリジナルではtensorflowで書かれていたものをkerasで書き換えました.全コードはこちらからダウンロードできます.

# below is where we need to do MDN (Mixture Density Network) splitting of
# distribution params
def get_mixture_coef(output, n_out):
  """Returns the tf slices containing mdn dist params."""
  # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
  z = output
  z = tf.reshape(z, [-1, n_out])
  z_pen_logits = z[:, 0:3]  # pen states
  z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1)

  # process output z's into MDN paramters

  # softmax all the pi's and pen states:
  z_pi = tf.nn.softmax(z_pi)
  z_pen = tf.nn.softmax(z_pen_logits)

  # exponentiate the sigmas and also make corr between -1 and 1.
  z_sigma1 = K.exp(z_sigma1)
  z_sigma2 = K.exp(z_sigma2)
  z_corr = tf.tanh(z_corr)

  r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
  return r


class SketchRNN():
  """SketchRNN model definition."""

  def __init__(self, hps):
    self.hps = hps # hps is hyper parameters
    self.build_model(hps)

  def build_model(self, hps):
    # VAE model = encoder + Decoder
    # build encoder model
    encoder_inputs = Input(shape=(hps.max_seq_len, 5), name='encoder_input')
    # (batch_size, max_seq_len, 5)
    encoder_lstm = LSTM(hps.enc_rnn_size,
                use_bias=True,
                recurrent_initializer='orthogonal',
                bias_initializer='zeros',
                recurrent_dropout=1.0-hps.recurrent_dropout_prob,
                return_sequences=True,
                return_state=True)
    bidirectional = Bidirectional(encoder_lstm)
    (unused_outputs, # (batch_size, max_seq_len, enc_rnn_size * 2)
    last_h_fw, unused_c_fw, # (batch_size, enc_rnn_size) * 2
    last_h_bw, unused_c_bw) = bidirectional(encoder_inputs)
    last_h = concatenate([last_h_fw, last_h_bw], 1)
    # (batch_size, enc_rnn_size*2)

    normal_init = RandomNormal(stddev=0.001)
    self.z_mean = Dense(hps.z_size,
                  activation='linear',
                  use_bias=True,
                  kernel_initializer=normal_init,
                  bias_initializer='zeros')(last_h) # (batch_size, z_size)
    self.z_presig = Dense(hps.z_size,
                  activation='linear',
                  use_bias=True,
                  kernel_initializer=normal_init,
                  bias_initializer='zeros')(last_h) # (batch_size, z_size)

    def sampling(args):
      z_mean, z_presig = args
      self.sigma = K.exp(0.5 * z_presig)
      batch = K.shape(z_mean)[0]
      dim = K.int_shape(z_mean)[1]
      epsilon = K.random_normal((batch, dim), 0.0, 1.0)
      batch_z = z_mean + self.sigma * epsilon

      return batch_z # (batch_size, z_size)

    self.batch_z = Lambda(sampling,
                    output_shape=(hps.z_size,))([self.z_mean, self.z_presig])

    # instantiate encoder model
    self.encoder = Model(
                    encoder_inputs,
                    [self.z_mean, self.z_presig, self.batch_z], name='encoder')
    # self.encoder.summary()

    # build decoder model
    # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
    # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
    # mixture weight/probability (Pi_k)
    self.n_out = (3 + hps.num_mixture * 6)

    decoder_inputs = Input(shape=(hps.max_seq_len, 5), name='decoder_input')
    # (batch_size, max_seq_len, 5)
    overlay_x = RepeatVector(hps.max_seq_len)(self.batch_z)
    # (batch_size, max_seq_len, z_size)
    actual_input_x = concatenate([decoder_inputs, overlay_x], 2)
    # (batch_size, max_seq_len, 5 + z_size)

    self.initial_state_layer = Dense(hps.dec_rnn_size * 2,
            activation='tanh',
            use_bias=True,
            kernel_initializer=normal_init)
    initial_state = self.initial_state_layer(self.batch_z)
    # (batch_size, dec_rnn_size * 2)
    initial_h, initial_c = tf.split(initial_state, 2, 1)
    # (batch_size, dec_rnn_size), (batch_size, dec_rnn_size)
    self.decoder_lstm = LSTM(hps.dec_rnn_size,
            use_bias=True,
            recurrent_initializer='orthogonal',
            bias_initializer='zeros',
            recurrent_dropout=1.0-hps.recurrent_dropout_prob,
            return_sequences=True,
            return_state=True
            )

    output, last_h, last_c = self.decoder_lstm(
                          actual_input_x, initial_state=[initial_h, initial_c])
    # [(batch_size, max_seq_len, dec_rnn_size), ((batch_size, dec_rnn_size)*2)]
    self.output_layer = Dense(self.n_out, activation='linear', use_bias=True)
    output = self.output_layer(output)
    # (batch_size, max_seq_len, n_out)

    last_state = [last_h, last_c]
    self.final_state = last_state

    # instantiate SketchRNN model
    self.sketch_rnn_model = Model(
                  [encoder_inputs, decoder_inputs],
                  output,
                  name='sketch_rnn')
    # self.sketch_rnn_model.summary()

  def vae_loss(self, inputs, outputs):
    # KL loss
    kl_loss = 1 + self.z_presig - K.square(self.z_mean) - K.exp(self.z_presig)
    self.kl_loss = -0.5 * K.mean(K.sum(kl_loss, axis=-1))
    self.kl_loss = K.maximum(self.kl_loss, K.constant(self.hps.kl_tolerance))

    # the below are inner functions, not methods of Model
    def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
      """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
      norm1 = subtract([x1, mu1])
      norm2 = subtract([x2, mu2])
      s1s2 = multiply([s1, s2])
      # eq 25
      z = (K.square(tf.divide(norm1, s1)) + K.square(tf.divide(norm2, s2)) -
           2 * tf.divide(multiply([rho, multiply([norm1, norm2])]), s1s2))
      neg_rho = 1 - K.square(rho)
      result = K.exp(tf.divide(-z, 2 * neg_rho))
      denom = 2 * np.pi * multiply([s1s2, K.sqrt(neg_rho)])
      result = tf.divide(result, denom)
      return result

    def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                     z_pen_logits, x1_data, x2_data, pen_data):
      """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
      # This represents the L_R only (i.e. does not include the KL loss term).

      result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2,
                             z_corr)
      epsilon = 1e-6
      # result1 is the loss wrt pen offset (L_s in equation 9 of
      # https://arxiv.org/pdf/1704.03477.pdf)
      result1 = multiply([result0, z_pi])
      result1 = K.sum(result1, 1, keepdims=True)
      result1 = -K.log(result1 + epsilon)  # avoid log(0)

      fs = 1.0 - pen_data[:, 2]  # use training data for this
      fs = tf.reshape(fs, [-1, 1])
      # Zero out loss terms beyond N_s, the last actual stroke
      result1 = multiply([result1, fs])

      # result2: loss wrt pen state, (L_p in equation 9)
      result2 = tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=pen_data, logits=z_pen_logits)
      result2 = tf.reshape(result2, [-1, 1])
      result2 = multiply([result2, fs])

      result = result1 + result2
      return result

    # reshape target data so that it is compatible with prediction shape
    target = tf.reshape(inputs, [-1, 5])
    [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1)
    pen_data = concatenate([eos_data, eoc_data, cont_data], 1)

    out = get_mixture_coef(outputs, self.n_out)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

    lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                            o_pen_logits, x1_data, x2_data, pen_data)

    self.r_loss = tf.reduce_mean(lossfunc)

    kl_weight = self.hps.kl_weight_start
    self.loss = self.r_loss + self.kl_loss * kl_weight
    return self.loss

  def model_compile(self, model):
      adam = Adam(lr=self.hps.learning_rate, clipvalue=self.hps.grad_clip)
      model.compile(loss=self.vae_loss, optimizer=adam)

以下は学習後のモデルからスケッチを生成するサンプリングのコードです.

def sample(model, hps, weights, seq_len=250, temperature=1.0,
            greedy_mode=False, z=None):
  """Samples a sequence from a pre-trained model."""

  def adjust_temp(pi_pdf, temp):
    pi_pdf = np.log(pi_pdf) / temp
    pi_pdf -= pi_pdf.max()
    pi_pdf = np.exp(pi_pdf)
    pi_pdf /=pi_pdf.sum()
    return pi_pdf

  def get_pi_idx(x, pdf, temp=1.0, greedy=False):
    """Samples from a pdf, optionally greedily."""
    if greedy:
      return np.argmax(pdf)
    pdf = adjust_temp(np.copy(pdf), temp)
    accumulate = 0
    for i in range(0, pdf.size):
      accumulate += pdf[i]
      if accumulate >= x:
        return i
    tf.logging.info('Error with smpling ensemble.')
    return -1

  def sample_gaussian_2d(mu1, mu2, s1, s2, rho, temp=1.0, greedy=False):
    if greedy:
      return mu1, mu2
    mean = [mu1, mu2]
    s1 *= temp * temp
    s2 *= temp * temp
    cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]

  # load model
  model.sketch_rnn_model.load_weights(weights)

  prev_x = np.zeros((1, 1, 5), dtype=np.float32)
  prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
  if z is None:
    z = np.random.randn(1, hps.z_size)

  batch_z = Input(shape=(hps.z_size,)) # (1, z_size)
  initial_state = model.initial_state_layer(batch_z)
  # (1, dec_rnn_size * 2)

  decoder_input = Input(shape=(1, 5)) # (1, 1, 5)
  overlay_x = RepeatVector(1)(batch_z) # (1,1, z_size)
  actual_input_x = concatenate([decoder_input, overlay_x], 2)
  # (1, 1, 5 + z_size)

  decoder_h_input = Input(shape=(hps.dec_rnn_size, ))
  decoder_c_input = Input(shape=(hps.dec_rnn_size, ))
  output, last_h, last_c = model.decoder_lstm(
                        actual_input_x,
                        initial_state=[decoder_h_input, decoder_c_input])
  # [(1, 1, dec_rnn_size), (1, dec_rnn_size), (1, dec_rnn_size)]
  output = model.output_layer(output)
  # (1, 1, n_out)

  decoder_initial_model = Model(batch_z, initial_state)
  decoder_model = Model([decoder_input, batch_z,
                        decoder_h_input, decoder_c_input],
                        [output, last_h, last_c])

  prev_state = decoder_initial_model.predict(z)
  prev_h, prev_c = np.split(prev_state, 2, 1)
  # (1, dec_rnn_size), (1, dec_rnn_size)

  strokes = np.zeros((seq_len, 5), dtype=np.float32)
  greedy = False
  temp =  1.0

  for i in range(seq_len):
    decoder_output, next_h, next_c = decoder_model.predict(
                                          [prev_x, z, prev_h, prev_c])
    out = sketch_rnn_model.get_mixture_coef(decoder_output, model.n_out)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

    o_pi = K.eval(o_pi)
    o_mu1 = K.eval(o_mu1)
    o_mu2 = K.eval(o_mu2)
    o_sigma1 = K.eval(o_sigma1)
    o_sigma2 = K.eval(o_sigma2)
    o_corr = K.eval(o_corr)
    o_pen = K.eval(o_pen)

    if i < 0:
      greedy = False
      temp = 1.0
    else:
      greedy = greedy_mode
      temp = temperature

    idx = get_pi_idx(random.random(), o_pi[0], temp, greedy)

    idx_eos = get_pi_idx(random.random(), o_pen[0], temp, greedy)
    eos=[0, 0, 0]
    eos[idx_eos] = 1

    next_x1, next_x2 = sample_gaussian_2d(o_mu1[0][idx], o_mu2[0][idx],
                                          o_sigma1[0][idx], o_sigma2[0][idx],
                                          o_corr[0][idx], np.sqrt(temp), greedy)

    strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]

    prev_x = np.zeros((1, 1, 5), dtype=np.float32)
    prev_x[0][0] = np.array(
        [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
    prev_h, prev_c = next_h, next_c

  # delete model to avoid a memory leak
  K.clear_session()

  return strokes

結果

今回は時間の都合上フクロウのデータセットでのみモデルの学習を行いました. それではまず入力に用いるスケッチをテストセットからランダムに選んできて,どのようなスケッチか見てみましょう.ちなみにこれは人間がフクロウを描いたものです.

f:id:k_kawabata:20181029021841p:plain

正直フクロウっぽくないですが一応生き物っぽいので良しとします.

次にこのスケッチからエンコーダーによって潜在ベクトルを得ます.

f:id:k_kawabata:20181029022225p:plain

そして最後にデコーダーによってスケッチを生成します.

f:id:k_kawabata:20181029022427p:plain

なかなかフクロウっぽいのではないでしょうか?

それでは次に様々なtemperatureパラメータを用いた時にスケッチがどのように変化していくか見ていきましょう.

f:id:k_kawabata:20181029022849p:plain

右にいくほどtemperatureの値は大きくなります.つまり,よりランダムになっていきます.temperatureが0.1の時はどちらかと言うとペンギンぽいですが,temperatureが0.3と0.5の時はかなりフクロウっぽいスケッチになっています.

かなり荒削りな部分もありますが,スケッチとして認識できるレベルまでしっかり学習ができていることがわかります.ただ,一つ気になったこととしては入力画像にかかわらず常にペンギンのようなスケッチを生成してしまっていたことです.ここの例でも示している通り,入力をかなり無視してペンギンのようなスケッチを生成しています.論文著者のコードではRNNセルにLayer Normalization付きのLSTMやHyperLSTMを用いることができるようになっており,またKL Lossのアニーリングも行なっていたのですが今回の実装ではそれらは含まれていなかったためこのような結果になったのではないかと考えています.入力画像に関係なく毎回同じようなスケッチを生成することに関しては,特にKL Lossのアニーリングを行なっていないことで,Reconstruction Lossに比べてKL Lossにばかり重点が置かれたことが原因だと考えています.

参照

最後に

尚,エクサウィザーズは20卒向けのAIエンジニアのポジションで,内定直結型インターンを東京・京都オフィスで募集しています.ご興味を持たれた方はぜひご応募ください.

ご応募はこちらから