エクサウィザーズ Engineer Blog

株式会社エクサウィザーズのエンジニアチームブログ

いつでも心電図が取れたとして? CES 2019 レポート 3日目

こんにちは。ロボットチーム、ソフトウェアエンジニアの鳥居です。

昨日に引き続き、CES 2019のレポートです。エンジニア視点で面白かったものを紹介します。

ウェアラブルデバイスによる医療データの収集

f:id:eitaro-torii-exwzd:20190111184142j:plain
バイオメトリックセンサが入った様々なウェアラブルデバイス (Valencell)

医療に関するセッションや医療・健康に関するプロダクトの展示も多数ありました。その中の1つで「ヒアラブル(hearable)」デバイスの話が面白かったので紹介します。

ヒアラブルデバイスはウェアラブルデバイスの1種で、音楽を聴くために身に付けるもののことを指します。ヒアラブルデバイスは肌に直接触れるので、センサーを取り付けることで健康状態に影響する様々な情報を取得することができます。このようなデバイスは年々増えており、また取得されるデータも検査のタイミングだけでなく身につけている限り連続的に取得できるようになると思われます。より質の良いデータを得ることで、医師や患者のコミュニケーションが円滑になったり、患者自身がより自分の身体を知ることが期待されます。一方、データの増加に伴って重要になるのは「必要な情報をいかに絞り込むか」という点ですが、これは今後発展していく分野と考えられています。

ヒアラブルデバイスが特徴的なのは、このようなウェアラブルデバイスの中でも多くの割合を占めていることと、他にも活用の場面があると考えられていることです。耳に埋め込むデバイスはその内側から「本人が喋った声」を抽出することができるので音声入力のデバイスとして使用したり、聞き分けたい音を動的に切り替えるような「Augmented hearing」などの方向性が考えられています。

写真は Valencell という会社がセンサを提供しているウェアラブルデバイスを集めた展示のものです。既に多くの種類のデバイスで自身の健康状態が把握できるようになっています。

オピオイド問題への解決策

f:id:eitaro-torii-exwzd:20190111183902j:plain
神経痛と対応する脊椎の位置を示す展示 (NANS)

アメリカでは現在、麻薬系鎮痛剤を多量に処方されてしまうことで依存症や過剰摂取による死亡事故が発生しており、この問題はオピオイド問題と呼ばれています。オピオイドは特に慢性の痛みに対して処方されることが多く、オピオイドを使わない方法として神経に電気的な刺激を与えることで慢性の痛みを軽減するアプローチが提案されています。写真は神経痛を抑えるために脊椎のどの位置を刺激すればよいかを表したモデルで、この位置に電気的な刺激を与えるデバイスを埋め込んで薬の代わりに電気刺激で痛みを和らげることができます。アメリカでのみ問題となっていることですが、技術で社会的な問題に挑戦する姿勢は素晴らしいと思います。

エクサウィザーズでは医療の問題を解決したいエンジニアを募集しています。ご興味を持たれた方はぜひご応募ください!採用情報|株式会社エクサウィザーズ

また、ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

5Gの可能性とは CES 2019 レポート 2日目

こんにちは。ロボットチーム、ソフトウェアエンジニアの鳥居です。

昨日に引き続き、CES 2019のレポートです。エンジニア視点で面白かったものを紹介します。

5Gとエッジコンピューティング

CESでは展示の他に、業界のリーダーが集まって特定のトピックについて話をするイベントが多数開催されています。5Gとエッジコンピューティングについてのセッションが面白そうだったので聞きに行ってみました。

興味深かった話をざっくり箇条書きで:

  • 5Gによる通信が一般で使われるようになると「どこで計算するか」の依存性が薄くなる
    • レイテンシが小さくなりスループットが上がるので、データの受け渡しがすぐにできるようになる
    • 例えば、現在の AR/VR用のヘッドマウントディスプレイは本体に計算処理を行うモジュールが含まれるのでサイズが大きめだが、このモジュールを外して小さくすることができる
  • 計算のトレンドは振り子のように揺れている
    • (1980年代)メインフレーム,集中型 -> (1990年代) PC, 分散型 -> (2000年代) クラウド, 集中型 -> (20XX年代) エッジコンピューティング, 分散型
    • 消費者が持っている計算リソースが年々増加しており、高速なlast 1 mileの通信の登場によってエッジコンピューティングのメリットが大きくなる。

IBMの量子コンピューター

昨日のキーノートで、https://newsroom.ibm.com/2019-01-08-IBM-Unveils-Worlds-First-Integrated-Quantum-Computing-System-for-Commercial-Use#assets_all:IBMが商用量子コンピュータについてのアナウンスがありました。IBMのブースに模型が展示されてますので行ってみました。

f:id:eitaro-torii-exwzd:20190110134721j:plain
IBM Q System One の模型

研究室の外で動く商用量子コンピュータは世界初ということで、次の世代のコンピューティングを始めようとする意気込みが伝わってきました。

エクサウィザーズでは新しい時代のエンジニアリングに挑戦したいエンジニアを募集しています。ご興味を持たれた方はぜひご応募ください!採用情報|株式会社エクサウィザーズ

また、ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

次はVPUの時代が来る? CES 2019 レポート 1日目

こんにちは。ロボットチーム、ソフトウェアエンジニアの鳥居です。

新しいプロダクトのアイデアを得るために、ロボットチーム3名でコンシューマー・エレクトロニクス・ショー(CES) というイベントに参加しています。アメリカのラスベガスで毎年開催されているイベントで、今年の一般公開は1/8〜1/11の日程となっています。

この記事では、エンジニア視点で面白かったものをピックアップして紹介していきたいと思います。

AIセキュリティカメラ、VPU

f:id:eitaro-torii-exwzd:20190109105311j:plain
CES 2019 Innovation Awardを受賞した SimCam

家庭用の監視カメラですが、顔識別やネットワークの機能が備わっているものです。画像の識別の機能はクラウドベースで提供されているものが多いですが、このプロダクトは小さなカメラの中にVisual Processing Unit (VPU)を搭載しているため、クラウドへのデータアップロードが不要となっています。クラウド利用のコストが不要になることやプライバシー保護の観点で優れた特徴です。Google Assistant や Amazon Alexaとも連携ができ、その場合でも画像をアップロードすることなく動作することができます。展示では「知らない人が玄関に近づいたらAlexaが声を掛ける」という使用例が出ていました。

面白いなと思ったのは、既存ハードウェアでは解決できてない問題を新しい種類のハードウェアを使って解決しに行っているところです。AIを使った計算というとGPUを使うものと思ってしまいますが、識別を主に行うプロダクトにはGPUは過剰なリソースとなり消費電力などの問題を抱えてしまいます。VPUはカメラから取得した画像を処理するのに特化した計算機で、画像向けの機械学習の計算に適した構造のハードウェアとのことです。カメラの中にVPUを積むアイデアがいくつもの問題を解決していてプロダクトとしての魅力を高めていることにとても面白いと思いました。

展示会場ではVPUというキーワードを他にも聞くことができました。画像処理AIの消費者向けプロダクトは、VPUがキーワードになるのかもしれません。

オムロンのパレタイジングと卓球

f:id:eitaro-torii-exwzd:20190109113313j:plain
卓球ロボット (オムロン)

f:id:eitaro-torii-exwzd:20190109112726j:plain
パレタイズ (オムロン)

ロボットの展示コーナーで一番最初に見えてくるのがオムロンの卓球ロボットでした。ロボットアームの先にラケットが付いていて、人がロボットと対戦すると対戦終了後に「ここをこうするともっと上手くなる」というチャートが表示されるという展示でした。その横で同じロボットを使ったパレタイジングのデモがあるのですが、パレタイジングがとても速い!目を引く卓球の展示と、単体だとちょっと地味なパレタイジングのデモを同じロボットでやっているのをみて、見せ方が面白さを感じました。

エクサウィザーズでは新しいプロダクトを作ることに興味のあるエンジニアを募集しています。ご興味を持たれた方はぜひご応募ください!採用情報|株式会社エクサウィザーズ

また、ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

ブラウザで動くAI類語地図

こんにちは。エクサウィザーズAIエンジニアの須藤です。

類語辞典って便利ですよね。 書いた文章がしっくりこないときに、ニュアンスの違う単語に置き換えたり、和語と漢語と外来語を入れ替えたりできます。

しかし、適当な表現が出てこなくて、もっと漠然と言葉を探したい時はどうしましょう。 辞書を繰り返し引いて、類語の類語の類語を見て回るのはちょっと面倒ですね。 かと言って項目あたりの類語数が増えたとしても、その中で探す手間が増えて、やはり使いづらそうです。

そこで、こういうものを作ってみました。 f:id:kentaro-suto:20181128154227p:plain https://base.exawizards.com/view/modelDetail?id=44

詳細を以下で解説します。

概要

類語を平面に分散して表示するプログラムです。 JavaScriptで書かれていますので、最近のブラウザがあれば動きます。 類語は見出し語に似ているものほど中央に、かつ類語同士でも似たものほど近くに表示されるようになっています。 つまり、よりしっくりくる単語の周りを優先して探せば、素早く目的の単語を見つけられることになります。 ...ということを期待して作りました。

単語の位置は、Word2vecという手法で得られた意味ベクトルから計算して求めています。 ここで意味ベクトル*1とは、単語の意味を数値化したものです。 関連の深い単語同士はベクトルの値も近くなります。 実際のテキストから、後述する機械学習によって得られます。

学習

  1. 前処理
    • テキストを形態素解析エンジンであるMeCabで分解します。
    • 単語ごとに登場頻度を調べ、高い順に番号をつけます。
    • テキストを単語番号の列に変換します。
  2. 学習
    • 単語数とベクトルの次元数を決めておきます。
    • 前後の単語が与えられたときに真ん中の単語がどれかを選択肢の中から予測する、以下の学習モデル*2を構築します。 f:id:kentaro-suto:20181122170033p:plain
    • 予測自体に有用性は特にありません。欲しいのは副産物である重みパラメータの値です。
    • 下記の設定でソースごとに学習を行いました。
ソース ウィキペディア 青空文庫 小説家になろう*3
MeCab辞書 NEologd デフォルト NEologd
テキスト量 約6億語 約5000万語 約2億5000万語
全ての単語の種類 182,669 108,676 227,685
学習した単語の種類 100,000 100,000 100,000
ベクトル次元 300 300 300
学習回数*4 5億 2.8億 6.3億
  • 後処理
    • 重みWは単語数×ベクトルの次元を持った行列です。これがそのまま各単語の意味ベクトルを表します。
    • 単語とベクトルの関係をオブジェクトリテラル形式に書き出し、可視化プログラムで使えるようにします。

可視化

以下の方針で単語の表示位置を決定します。

  • 見出し語を画面の中心に配置する。
  • 見出し語に似た単語ほど、中央近くに配置する。
  • 似ていない単語同士が、できるだけ離れるようにする。

具体的には以下のような計算を行います。

  • 全ての見出し語の意味ベクトルを調べ\{a_i\}とし、その平均aを求めます。これが画面の中心にあたります。
  • aに近いベクトルを30個選び\{b_i\}とし、その平均bも求めます。これらが分布している主な方向を以下で求めます。
  • a-bと直交し分散 \overline{(\xi\cdot(b_i-b))^2}を最大にする方向ベクトル\xiを求めます。これがX方向になります。
  • a-bおよび\xiと直交し分散 \overline{(\psi\cdot(b_i-b))^2}を最大にする別の方向ベクトル\psiを求めます。これがY方向になります。
  • 意味ベクトルがv_iになる単語の画面中心からの距離をr={\sigma}\left(\frac{||a|| ||v_i||}{a\cdot{v_i}}\right)^6とします。式の形および指数6に深い意味はありません。色々試した中で表示結果が見やすかったものを選びました。{\sigma}はスケールで、\{b_i\}に対応する単語が画面にちょうど収まるように調節します。
  • 画面中心からの方向ベクトルを(\hat x_i,\hat y_i)=\frac{(\xi,\psi)\cdot(v_i-b)}{||(\xi,\psi)\cdot(v_i-b)||}で求めます。
  • 最後に座標を、(x_i,y_i)=r (\hat x_i,\hat y_i)+{\mu}で求めます。{\mu}は画面の中心の座標です。

使い方

  1. ソースの選択
    • ポップアップボタンで学習ソースを選択してください。
    • 選択すると、データの読み込みが始まります。データファイルは10〜20MBあるので、読み込みには時間がかかります。
  2. 見出し語の入力
    • テキストフィールドに単語を入力してください。
    • 複数の単語を入力できます。区切り文字が無くても、自動的に知っている単語に分解します。スペースなどで明示的に分解もできます。
    • 自立語*5が対象です。動詞と形容詞は終止形、形容動詞は語幹にしてください*6
  3. 画面の操作
    • 画面をドラッグするとスクロールできます。
    • 右下のスライダーを操作すると、ズームできます。
    • 単語をダブルクリックすると、その単語を新たな見出し語として再描画します。
  4. その他
    • 学習ソースと見出し語はアドレスバーに逐次反映されます。URLをコピーして別のブラウザで開くと、状態が再現されます。

使用例

ソースの違い

学習ソースを変えると、異なる観点で類語を検索することができます。

ウィキペディア 青空文庫 小説家になろう
f:id:kentaro-suto:20181122141733p:plain f:id:kentaro-suto:20181122141757p:plain f:id:kentaro-suto:20181122141808p:plain
妙に偏った選別になりました。題名などで繰り返し使われると影響を受けやすいようです。 犬と猫は使われる文脈が似ているため、類語あつかいになってしまいます。 様々な猫の呼び方が出てきます。
タンク f:id:kentaro-suto:20181122143214p:plain f:id:kentaro-suto:20181122143223p:plain f:id:kentaro-suto:20181122143305p:plain
水槽関連、燃料関連に分かれました。 燃料関連、戦車関連です。 タンク職関連が大勢を占めます。
四天王 f:id:kentaro-suto:20181122150948p:plain f:id:kentaro-suto:20181122150956p:plain f:id:kentaro-suto:20181122151004p:plain
比喩的な使われ方のほか、元祖四天王もカバーします。 仏教関連のほか、称号の例として頼光四天王が出てきました。 魔王の部下として倒される存在です。

興味深い例

制作の過程で気がついた、興味深い例を紹介します。 学習結果は随時アップデートしているので、現在のデータで試してもこの通りにならないことをご了承ください。

東京 京都 ローマ
f:id:kentaro-suto:20181122153111p:plain f:id:kentaro-suto:20181122153459p:plain f:id:kentaro-suto:20181122153729p:plain
東京が付く言葉、東京都の下部組織、東京の地名などに分かれます。 京都府の地名、京都市の地名、観光地に分かれます。 イタリアの各都市、古代ローマの詳細、他の文明などが出てきます。
家康 正義 しかし
f:id:kentaro-suto:20181122153909p:plain f:id:kentaro-suto:20181122154041p:plain f:id:kentaro-suto:20181122154322p:plain
戦国時代の家康、江戸時代の家康、安土桃山時代の家康に別れました。 正しい行いという意味の言葉と、名前が正義の人の苗字が出てきます。 接続詞なども入っています。

最後に

Word2vecの二次元可視化によって、言葉探しが楽しく便利に行えるようになる可能性を示しました。 こちらですぐ試せます。

尚、エクサウィザーズは優秀なエンジニア、社会課題を一緒に解決してくれる魔法使い”ウィザーズ”を募集しています。ご興味を持たれた方はぜひご応募ください。 採用情報|株式会社エクサウィザーズ

ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

*1:埋め込みベクトルまたは分散表現とも言います

*2:Negative Samplingという技術の変形です

*3:「小説家になろう」は 株式会社ヒナプロジェクトの登録商標です

*4:バッチサイズ×ステップ数×エポック数

*5:助詞と助動詞以外

*6:例:長く→長い、のびろ→のびる、毛むくじゃらな→毛むくじゃら

Sketch-RNN でスケッチの自動生成(VAE + LSTM)

こんにちは.エクサウィザーズでインターンをしている川畑です.

視覚によるコミュニケーションというのは人々が相手に何らかのアイデアを伝える際に鍵となります.私たちは小さい頃から物体を描く力を養ってきており,時には感情までもたった複数の線で表現することも可能です.こうした単純な絵というのは,身の回りのものを写真のように捉え忠実に再現したものではなく,どのようにして人間が物体の特徴を認識しそれらを再現するか,ということを教えてくれます.

そこで今回はSketch-RNNと呼ばれるRecurrent Neural Networkモデルでのスケッチの自動生成に取り組んでみました. このモデルは人間がするのと同じように抽象的な概念を一般化し,スケッチを生成することを目的としたものです.このモデルに関しては今の所具体的なアプリケーションが存在するというわけではなく,機械学習がどのようにクリエイティブな分野で活用できるか,という一例を提案したものになります.ソースコードはこちらからダウンロードできます.

f:id:k_kawabata:20181027163327p:plain [1] Google AI Blog: Teaching Machines to Draw

データセット

今回は論文でも用いられていたQuick, Draw!のデータセットを使用しました.これは20秒以内であるお題の絵を描くゲームで,データセットとしてネコやブタ,バスなど数百個のクラスのデータを公開しています.各クラスのデータは,訓練,検証,テストとしてそれぞれ70000,2500,2500のデータに分かれています.

Sketch-RNNモデルを学習させるにあたり,データのフォーマットとしては各点のデータが5つの要素からなるベクトル \left(\Delta x, \Delta y, p_{1}, p_{2}, p_{3} \right)を使用しました.最初の二つの要素は,前の点からの x, yの変位,残りの三つの要素はone-hotベクトルとなっています.一つ目のペン状態 p_{1}はペンが現在紙に接していて,次の点まで線が描かれることを示しており,二つ目のペン状態 p_{2}はそこでペンが紙から離れ,次の点まで線は引かれないことを示しており,三つ目のペン状態 p_{3}はスケッチが終了することを示しています.

Sketch-RNNモデル

それではモデルについて解説していきたいと思います.

[2] A Neural Representation of Sketch Drawings

f:id:k_kawabata:20181027164205p:plain

モデルの大枠はSequence-to-Sequence Variational Autoencoder (VAE) でできています.まず,エンコーダーであるbidirectional RNN (今回は単純なLSTMを使用)にスケッチのシーケンスを入力し,出力として潜在ベクトル zを得ます.具体的には,双方向のRNNから得られた隠れ状態を連結し,全結合層によって連結された隠れ状態から \muおよび \hat{\sigma}を得ます. \hat{\sigma}に関しては,負にならないように \sigma = \exp \left ({\frac{\hat{\sigma}}{2}} \right)の操作を加えます.そしてさらにガウス分布 \mathcal{N}(0, I)と組み合わせることで最終的に潜在ベクトルを計算することができます.式で書くと以下のようになります.

 \displaystyle \mu = W_{\mu}h + b_{\mu}, \;  \hat{\sigma} = W_{\sigma}h + b_{\sigma}, \; \sigma = \exp \left ({\frac{\hat{\sigma}}{2}} \right), \; z = \mu + \sigma\odot\mathcal{N}(0, I)

潜在ベクトルが得られたら,次はデコーダーです.隠れ状態の初期値には潜在ベクトルに tanh関数を掛けたものを用います.

 \displaystyle \left[h_{0}; c_{0} \right] = tanh(W_{z}z + b_{z})

各ステップでのデコーダーの入力には前ステップでのストローク S_{i-1}と潜在ベクトル zを連結したものを与え, S_{0}にはスケッチの開始を意味する (0, 0, 1, 0, 0)を与えるようにします.各ステップでの出力は次のデータ点の確率分布に関するパラメータを返します.Sketch-RNNでは M個の正規分布からなるGaussian mixture model (GMM)により \left(\Delta x, \Delta y \right)を,カテゴリカル分布 \left(q_{1}, q_{2}, q_{3} \right) \left(q_{1} + q_{2} + q_{3} = 1 \right)により真値である \left(p_{1}, p_{2}, p_{3} \right)をモデル化します.

 \displaystyle p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi _{j} \mathcal{N} \left(\Delta x, \Delta y \mid \mu_{x, j}, \mu_{y, j}, \sigma_{x, j}, \sigma_{y, j}, \rho_{xy, j} \right), where \sum_{j=1}^{M} \Pi_{j} = 1

上の式において \mathcal{N} \left(x, y \mid \mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy} \right)は2変量正規分布を表しており,5つのパラメータ \left(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy} \right)で構成されます. \mu_{x}, \mu_{y}は平均,  \sigma_{x}, \sigma_{y}は標準偏差, \rho_{xy} x, yの相関係数です.また, \Piは長さ Mのベクトルであり,Gaussian mixture modelの各正規分布の重みに相当します.ですので,出力のサイズは3つのロジット \left(q_{1}, q_{2}, q_{3} \right)も含めて合計で 5M + M + 3となります.

 \displaystyle x_{i} = \left[S_{i-1}; z \right], \left[h_{i}; c_{i} \right]

 y_{i} = W_{y}h_{i} + b_{y}, \; y_{i} \in \mathbb{R}^{6M+3}

ここで,出力ベクトル y_{i}は次のように分解されます.

 \displaystyle \left[ (\hat{\Pi}_{1} \, \mu_{x} \, \mu_{y} \, \hat{\sigma}_{x} \, \hat{\sigma}_{y} \,  \hat{\rho}_{xy})_{1} \ldots (\hat{\Pi}_{M} \, \mu_{x} \, \mu_{y} \, \hat{\sigma}_{x} \, \hat{\sigma}_{y} \,  \hat{\rho}_{xy})_{M} \; (\hat{q}_{1} \, \hat{q}_{2} \, \hat{q}_{3})  \right] = y_{i}

標準偏差,相関係数に関してはそれぞれ負にならないように,また-1から1の値を取るように exp tanhによる操作を施します.

 \displaystyle \sigma_{x} = \exp\left(\hat{\sigma}_{x} \right), \; \sigma_{y} = \exp\left(\hat{\sigma}_{y} \right), \; \rho_{xy} = tanh\left(\hat{\rho}_{xy}  \right)

カテゴリカル分布に関してはSoftmax関数を適応させ,全ての値が0〜1に収まるようにします.

 \displaystyle q_{k} = \frac{\exp(\hat{q}_{k})}{\sum_{j=1}^{3} \exp(\hat{q}_{j})}, \; k \in \{1, 2, 3\}, \; \Pi_{k} = \frac{\exp(\hat{\Pi}_{k})}{\sum_{j=1}^{M} \exp(\hat{\Pi}_{j})}, \; k \in \{1, \ldots , M\}

損失関数

損失関数はVAEと同じでReconstruction Loss,  L_{R}とKullback-Leibler (KL) Divergence Loss,  L_{KL}の二つの項からなります.

  • Reconstruction Loss,  L_{R}

Reconstruction Loss項は訓練データ Sを説明する確率分布の対数尤度を表しており,これを最大化するように学習します.そして,Reconstruction Loss,  L_{R}はさらに座標に関する項 L_{s}とペン状態に関する項 L_{p} からなっており,それぞれ 以下のように書くことができます.

 \displaystyle L_{s} = -\frac{1}{N_{max}} \sum_{i=1}^{N_{s}}log \left(\sum_{j=1}^{M}\Pi_{j, i} \mathcal{N} \left(\Delta x_{i}, \Delta y_{i} \mid \mu_{x, j, i}, \mu_{y, j, i}, \sigma_{x, j, i}, \sigma_{y, j, i}, \rho_{xy, j, i} \right)\right)

 \displaystyle L_{p} = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k, i} log\left(q_{k, i} \right)

 L_{R} = L_{s} + L_{p}

  • Kullback-Leibler (KL) Divergence Loss,  L_{KL}

KL Divergence Loss項は潜在ベクトル zが標準正規分布からどれだけ離れているかを表しており,これを最小化するように学習することになります.

 \displaystyle L_{KL} = -\frac{1}{2N_{z}} \left(1 + \hat{\sigma} - \mu^{2} - \exp(\hat{\sigma}) \right)

実際に最適化する損失関数には L_{R} L_{KL}を重み付けして足し合わせたものを用います.

 Loss = L_{R} + w_{KL}L_{KL}

それぞれの損失項にはトレードオフの関係があり, w_{KL} \rightarrow 0の時にはモデルは純粋なオートエンコーダーに近づき,より良いReconstruction Lossを得ることができる一方で,潜在空間における事前分布の強化を犠牲にすることになります.

サンプリング

モデルを学習させた後はいよいよスケッチの生成です.サンプリング過程では各ステップごとにGMMとカテゴリカル分布のパラメータを得,そのステップでの出力 S_{i}'を得ます.訓練過程とは異なり,サンプリング過程では出力 S_{i}'を次のステップの入力とし,このサンプリングステップを p_{3} =1となるかステップ数が N_{max}となるまで繰り返していきます.最終的に出力を得る際にはdeterministicに確率密度関数で最も確率の高い点を選ぶのではなく,下記のようにtemperatureパラメータ \tauを導入することで出力のランダムさを調節できるようにしています. \tauは0〜1の値を取り, \tau = 0の時にはモデルはdeterministicになります.

 \displaystyle \hat{q}_{k} \rightarrow \frac{\hat{q}_{k}}{\tau}, \; \hat{\Pi}_{k} \rightarrow \frac{\hat{\Pi}_{k}}{\tau}, \; \sigma_{x}^{2} \rightarrow \sigma_{x}^{2}\tau, \; \sigma_{y}^{2} \rightarrow \sigma_{y}^{2}\tau

コード

モデルに関する部分のコードを示します.Githubに論文著者によるオリジナルのコードもありますが,オリジナルではtensorflowで書かれていたものをkerasで書き換えました.全コードはこちらからダウンロードできます.

# below is where we need to do MDN (Mixture Density Network) splitting of
# distribution params
def get_mixture_coef(output, n_out):
  """Returns the tf slices containing mdn dist params."""
  # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
  z = output
  z = tf.reshape(z, [-1, n_out])
  z_pen_logits = z[:, 0:3]  # pen states
  z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1)

  # process output z's into MDN paramters

  # softmax all the pi's and pen states:
  z_pi = tf.nn.softmax(z_pi)
  z_pen = tf.nn.softmax(z_pen_logits)

  # exponentiate the sigmas and also make corr between -1 and 1.
  z_sigma1 = K.exp(z_sigma1)
  z_sigma2 = K.exp(z_sigma2)
  z_corr = tf.tanh(z_corr)

  r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
  return r


class SketchRNN():
  """SketchRNN model definition."""

  def __init__(self, hps):
    self.hps = hps # hps is hyper parameters
    self.build_model(hps)

  def build_model(self, hps):
    # VAE model = encoder + Decoder
    # build encoder model
    encoder_inputs = Input(shape=(hps.max_seq_len, 5), name='encoder_input')
    # (batch_size, max_seq_len, 5)
    encoder_lstm = LSTM(hps.enc_rnn_size,
                use_bias=True,
                recurrent_initializer='orthogonal',
                bias_initializer='zeros',
                recurrent_dropout=1.0-hps.recurrent_dropout_prob,
                return_sequences=True,
                return_state=True)
    bidirectional = Bidirectional(encoder_lstm)
    (unused_outputs, # (batch_size, max_seq_len, enc_rnn_size * 2)
    last_h_fw, unused_c_fw, # (batch_size, enc_rnn_size) * 2
    last_h_bw, unused_c_bw) = bidirectional(encoder_inputs)
    last_h = concatenate([last_h_fw, last_h_bw], 1)
    # (batch_size, enc_rnn_size*2)

    normal_init = RandomNormal(stddev=0.001)
    self.z_mean = Dense(hps.z_size,
                  activation='linear',
                  use_bias=True,
                  kernel_initializer=normal_init,
                  bias_initializer='zeros')(last_h) # (batch_size, z_size)
    self.z_presig = Dense(hps.z_size,
                  activation='linear',
                  use_bias=True,
                  kernel_initializer=normal_init,
                  bias_initializer='zeros')(last_h) # (batch_size, z_size)

    def sampling(args):
      z_mean, z_presig = args
      self.sigma = K.exp(0.5 * z_presig)
      batch = K.shape(z_mean)[0]
      dim = K.int_shape(z_mean)[1]
      epsilon = K.random_normal((batch, dim), 0.0, 1.0)
      batch_z = z_mean + self.sigma * epsilon

      return batch_z # (batch_size, z_size)

    self.batch_z = Lambda(sampling,
                    output_shape=(hps.z_size,))([self.z_mean, self.z_presig])

    # instantiate encoder model
    self.encoder = Model(
                    encoder_inputs,
                    [self.z_mean, self.z_presig, self.batch_z], name='encoder')
    # self.encoder.summary()

    # build decoder model
    # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
    # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
    # mixture weight/probability (Pi_k)
    self.n_out = (3 + hps.num_mixture * 6)

    decoder_inputs = Input(shape=(hps.max_seq_len, 5), name='decoder_input')
    # (batch_size, max_seq_len, 5)
    overlay_x = RepeatVector(hps.max_seq_len)(self.batch_z)
    # (batch_size, max_seq_len, z_size)
    actual_input_x = concatenate([decoder_inputs, overlay_x], 2)
    # (batch_size, max_seq_len, 5 + z_size)

    self.initial_state_layer = Dense(hps.dec_rnn_size * 2,
            activation='tanh',
            use_bias=True,
            kernel_initializer=normal_init)
    initial_state = self.initial_state_layer(self.batch_z)
    # (batch_size, dec_rnn_size * 2)
    initial_h, initial_c = tf.split(initial_state, 2, 1)
    # (batch_size, dec_rnn_size), (batch_size, dec_rnn_size)
    self.decoder_lstm = LSTM(hps.dec_rnn_size,
            use_bias=True,
            recurrent_initializer='orthogonal',
            bias_initializer='zeros',
            recurrent_dropout=1.0-hps.recurrent_dropout_prob,
            return_sequences=True,
            return_state=True
            )

    output, last_h, last_c = self.decoder_lstm(
                          actual_input_x, initial_state=[initial_h, initial_c])
    # [(batch_size, max_seq_len, dec_rnn_size), ((batch_size, dec_rnn_size)*2)]
    self.output_layer = Dense(self.n_out, activation='linear', use_bias=True)
    output = self.output_layer(output)
    # (batch_size, max_seq_len, n_out)

    last_state = [last_h, last_c]
    self.final_state = last_state

    # instantiate SketchRNN model
    self.sketch_rnn_model = Model(
                  [encoder_inputs, decoder_inputs],
                  output,
                  name='sketch_rnn')
    # self.sketch_rnn_model.summary()

  def vae_loss(self, inputs, outputs):
    # KL loss
    kl_loss = 1 + self.z_presig - K.square(self.z_mean) - K.exp(self.z_presig)
    self.kl_loss = -0.5 * K.mean(K.sum(kl_loss, axis=-1))
    self.kl_loss = K.maximum(self.kl_loss, K.constant(self.hps.kl_tolerance))

    # the below are inner functions, not methods of Model
    def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
      """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
      norm1 = subtract([x1, mu1])
      norm2 = subtract([x2, mu2])
      s1s2 = multiply([s1, s2])
      # eq 25
      z = (K.square(tf.divide(norm1, s1)) + K.square(tf.divide(norm2, s2)) -
           2 * tf.divide(multiply([rho, multiply([norm1, norm2])]), s1s2))
      neg_rho = 1 - K.square(rho)
      result = K.exp(tf.divide(-z, 2 * neg_rho))
      denom = 2 * np.pi * multiply([s1s2, K.sqrt(neg_rho)])
      result = tf.divide(result, denom)
      return result

    def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                     z_pen_logits, x1_data, x2_data, pen_data):
      """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
      # This represents the L_R only (i.e. does not include the KL loss term).

      result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2,
                             z_corr)
      epsilon = 1e-6
      # result1 is the loss wrt pen offset (L_s in equation 9 of
      # https://arxiv.org/pdf/1704.03477.pdf)
      result1 = multiply([result0, z_pi])
      result1 = K.sum(result1, 1, keepdims=True)
      result1 = -K.log(result1 + epsilon)  # avoid log(0)

      fs = 1.0 - pen_data[:, 2]  # use training data for this
      fs = tf.reshape(fs, [-1, 1])
      # Zero out loss terms beyond N_s, the last actual stroke
      result1 = multiply([result1, fs])

      # result2: loss wrt pen state, (L_p in equation 9)
      result2 = tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=pen_data, logits=z_pen_logits)
      result2 = tf.reshape(result2, [-1, 1])
      result2 = multiply([result2, fs])

      result = result1 + result2
      return result

    # reshape target data so that it is compatible with prediction shape
    target = tf.reshape(inputs, [-1, 5])
    [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1)
    pen_data = concatenate([eos_data, eoc_data, cont_data], 1)

    out = get_mixture_coef(outputs, self.n_out)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

    lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                            o_pen_logits, x1_data, x2_data, pen_data)

    self.r_loss = tf.reduce_mean(lossfunc)

    kl_weight = self.hps.kl_weight_start
    self.loss = self.r_loss + self.kl_loss * kl_weight
    return self.loss

  def model_compile(self, model):
      adam = Adam(lr=self.hps.learning_rate, clipvalue=self.hps.grad_clip)
      model.compile(loss=self.vae_loss, optimizer=adam)

以下は学習後のモデルからスケッチを生成するサンプリングのコードです.

def sample(model, hps, weights, seq_len=250, temperature=1.0,
            greedy_mode=False, z=None):
  """Samples a sequence from a pre-trained model."""

  def adjust_temp(pi_pdf, temp):
    pi_pdf = np.log(pi_pdf) / temp
    pi_pdf -= pi_pdf.max()
    pi_pdf = np.exp(pi_pdf)
    pi_pdf /=pi_pdf.sum()
    return pi_pdf

  def get_pi_idx(x, pdf, temp=1.0, greedy=False):
    """Samples from a pdf, optionally greedily."""
    if greedy:
      return np.argmax(pdf)
    pdf = adjust_temp(np.copy(pdf), temp)
    accumulate = 0
    for i in range(0, pdf.size):
      accumulate += pdf[i]
      if accumulate >= x:
        return i
    tf.logging.info('Error with smpling ensemble.')
    return -1

  def sample_gaussian_2d(mu1, mu2, s1, s2, rho, temp=1.0, greedy=False):
    if greedy:
      return mu1, mu2
    mean = [mu1, mu2]
    s1 *= temp * temp
    s2 *= temp * temp
    cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]

  # load model
  model.sketch_rnn_model.load_weights(weights)

  prev_x = np.zeros((1, 1, 5), dtype=np.float32)
  prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
  if z is None:
    z = np.random.randn(1, hps.z_size)

  batch_z = Input(shape=(hps.z_size,)) # (1, z_size)
  initial_state = model.initial_state_layer(batch_z)
  # (1, dec_rnn_size * 2)

  decoder_input = Input(shape=(1, 5)) # (1, 1, 5)
  overlay_x = RepeatVector(1)(batch_z) # (1,1, z_size)
  actual_input_x = concatenate([decoder_input, overlay_x], 2)
  # (1, 1, 5 + z_size)

  decoder_h_input = Input(shape=(hps.dec_rnn_size, ))
  decoder_c_input = Input(shape=(hps.dec_rnn_size, ))
  output, last_h, last_c = model.decoder_lstm(
                        actual_input_x,
                        initial_state=[decoder_h_input, decoder_c_input])
  # [(1, 1, dec_rnn_size), (1, dec_rnn_size), (1, dec_rnn_size)]
  output = model.output_layer(output)
  # (1, 1, n_out)

  decoder_initial_model = Model(batch_z, initial_state)
  decoder_model = Model([decoder_input, batch_z,
                        decoder_h_input, decoder_c_input],
                        [output, last_h, last_c])

  prev_state = decoder_initial_model.predict(z)
  prev_h, prev_c = np.split(prev_state, 2, 1)
  # (1, dec_rnn_size), (1, dec_rnn_size)

  strokes = np.zeros((seq_len, 5), dtype=np.float32)
  greedy = False
  temp =  1.0

  for i in range(seq_len):
    decoder_output, next_h, next_c = decoder_model.predict(
                                          [prev_x, z, prev_h, prev_c])
    out = sketch_rnn_model.get_mixture_coef(decoder_output, model.n_out)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

    o_pi = K.eval(o_pi)
    o_mu1 = K.eval(o_mu1)
    o_mu2 = K.eval(o_mu2)
    o_sigma1 = K.eval(o_sigma1)
    o_sigma2 = K.eval(o_sigma2)
    o_corr = K.eval(o_corr)
    o_pen = K.eval(o_pen)

    if i < 0:
      greedy = False
      temp = 1.0
    else:
      greedy = greedy_mode
      temp = temperature

    idx = get_pi_idx(random.random(), o_pi[0], temp, greedy)

    idx_eos = get_pi_idx(random.random(), o_pen[0], temp, greedy)
    eos=[0, 0, 0]
    eos[idx_eos] = 1

    next_x1, next_x2 = sample_gaussian_2d(o_mu1[0][idx], o_mu2[0][idx],
                                          o_sigma1[0][idx], o_sigma2[0][idx],
                                          o_corr[0][idx], np.sqrt(temp), greedy)

    strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]

    prev_x = np.zeros((1, 1, 5), dtype=np.float32)
    prev_x[0][0] = np.array(
        [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
    prev_h, prev_c = next_h, next_c

  # delete model to avoid a memory leak
  K.clear_session()

  return strokes

結果

今回は時間の都合上フクロウのデータセットでのみモデルの学習を行いました. それではまず入力に用いるスケッチをテストセットからランダムに選んできて,どのようなスケッチか見てみましょう.ちなみにこれは人間がフクロウを描いたものです.

f:id:k_kawabata:20181029021841p:plain

正直フクロウっぽくないですが一応生き物っぽいので良しとします.

次にこのスケッチからエンコーダーによって潜在ベクトルを得ます.

f:id:k_kawabata:20181029022225p:plain

そして最後にデコーダーによってスケッチを生成します.

f:id:k_kawabata:20181029022427p:plain

なかなかフクロウっぽいのではないでしょうか?

それでは次に様々なtemperatureパラメータを用いた時にスケッチがどのように変化していくか見ていきましょう.

f:id:k_kawabata:20181029022849p:plain

右にいくほどtemperatureの値は大きくなります.つまり,よりランダムになっていきます.temperatureが0.1の時はどちらかと言うとペンギンぽいですが,temperatureが0.3と0.5の時はかなりフクロウっぽいスケッチになっています.

かなり荒削りな部分もありますが,スケッチとして認識できるレベルまでしっかり学習ができていることがわかります.ただ,一つ気になったこととしては入力画像にかかわらず常にペンギンのようなスケッチを生成してしまっていたことです.ここの例でも示している通り,入力をかなり無視してペンギンのようなスケッチを生成しています.論文著者のコードではRNNセルにLayer Normalization付きのLSTMやHyperLSTMを用いることができるようになっており,またKL Lossのアニーリングも行なっていたのですが今回の実装ではそれらは含まれていなかったためこのような結果になったのではないかと考えています.入力画像に関係なく毎回同じようなスケッチを生成することに関しては,特にKL Lossのアニーリングを行なっていないことで,Reconstruction Lossに比べてKL Lossにばかり重点が置かれたことが原因だと考えています.

参照

最後に

尚,エクサウィザーズは20卒向けのAIエンジニアのポジションで,内定直結型インターンを東京・京都オフィスで募集しています.ご興味を持たれた方はぜひご応募ください.

ご応募はこちらから

Tensorflow.jsを用いたブラウザで動く物体認識

こんにちは。エクサウィザーズAIエンジニアの須藤です。 この度exaBaseの「物体名判別」モデルの紹介ページに、その場で試せるデモ機能を追加しました。

f:id:kentaro-suto:20181029164025p:plain

前回の「写真に写っていないところを復元する」とともに、実装にあたってはTensorflow.jsというフレームワークを使っています。 この記事では、Tensorflow.js導入までの簡単な解説と注意点、および新しいデモの操作方法を紹介したいと思います。

Tensorflow.jsとは

TensorflowもしくはKerasで書かれた機械学習モデルを、JavaScriptで扱えるようにするフレームワークです。 学習済みモデルによる推論が主な応用と考えられますが、モデルの構築や再学習も可能です。 WebGL経由でGPUを利用するので、計算は十分に高速です。

公式サイト

www.tensorflow.org

特徴

Webブラウザ上でAIモデルが動くようになります。 以下の特徴を持ったアプリケーションが作れます。

  • ライブラリ等のインストールが不要
    • (最新の)Webブラウザさえあれば動作します。
  • GUIが利用可能
    • Webアプリの技術がそのまま使えます。
    • ローカルファイルの選択などもOSやブラウザが面倒を見てくれます。
  • スマートフォンでも動作
  • サーバの負荷が少ない
    • 計算はクライアント側で行われるため
  • 細かいカスタマイズはできない

開発環境

pipでインストールできます。

pip install tensorflowjs

モデルの書き出し

モデルデータの書き出しには、コマンドtensorflowjs_converterまたは、Pythonフレームワークtensorflowjsが使えます。 後者で、既存の推論スクリプトに一時的に以下の行を書き足す方法が、手軽でおすすめです。

import tensorflowjs as tfjs
# モデルオブジェクトmodelを、ディレクトリ'tfjs'に、8ビットの量子化を用いて書き出す。
tfjs.converters.save_keras_model(model, 'tfjs', quantization_dtype=np.uint8)

実行すると、書き出し先ディレクトリに次のようなファイル群が生成されます。ファイルの名前と数はモデルのサイズや構成によって変わります。 総ファイルサイズは、重み(.hdf5)ファイルにほぼ比例します。量子化しなければそのまま、16ビットで半分、8ビットで4分の1になります。

model.json
group1-shard1of3
group1-shard2of3
group1-shard3of3

モデルの読み込み

HTML側でフレームワークをインポートします。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"></script>

モデルの読み込みを開始するには、JavaScriptで次のようにします。結果はPromiseオブジェクトを通して非同期にハンドラ関数に渡されます。

tf.loadModel('./model.json').then(handleModel).catch(handleError);
function handleModel(model) {
    // 正常に読み込まれた時の処理
    // 必要なら入出力shapeを保存
    height = model.inputs[0].shape[1];
    width = model.inputs[0].shape[2];
    // modelの操作...
}
function handleError(error) {
    // エラー処理
}

実行

Pythonで書かれた前処理や後処理を、頑張ってJavaScriptに翻訳します。実行そのものはモデルのpredictメソッドを呼ぶだけです。 predictを呼び出してから結果データを受け取るまでの流れは、形式上非同期ですが、Mac版Safariの場合、結果を待つ間もJavaScriptに制御が戻りませんでした。

var data = new Float32Array(height*width*3);
// dataの前処理...
var inputs = tf.tensor(data).reshape([1,height,width,3]); // テンソルに変換
var outputs = model.predict(inputs);
outputs.data().then(handleData).catch(handleError);
function handleData(data) { // Float32Arrayを受け取る    
    // dataの後処理...
}

モデルを読み込めない場合

原稿執筆時のバージョン(0.13.0)では、モデルに以下のものが含まれていると、書き出せても読み込めませんでした。 動作に影響しないものなら、model.jsonを書き換えることで解決する場合もあります。 どうしても必要な演算の場合、行列演算や畳み込みで代用するなどの工夫が必要です。

  • Subtract
    • AddやMultiplyは大丈夫なのに何故かこれはダメです。
  • Merge
  • Lambda
  • カスタマイズされたinitializer、regularizer、constraints
  • 名前の重複したレイヤー

物体名判別デモ

Kerasに同梱されている学習済みXceptionモデルを用いたデモです。 同モデルの重みはMITライセンスの下で公開されています。

こちらのページでお試しいただけます。 https://base.exawizards.com/view/modelDetail?id=6

操作方法

f:id:kentaro-suto:20181023183149p:plain 画面左下のリンクをクリックしてください。モデルのダウンロードが始まります。
f:id:kentaro-suto:20181023183216p:plain ダウンロードが終わると、プルダウンとファイル選択ボタンが表示されます。
f:id:kentaro-suto:20181024134942p:plain サンプル写真をプルダウンから選択するか、
f:id:kentaro-suto:20181023183415p:plain ローカルのファイルを選択してください。
f:id:kentaro-suto:20181023183443p:plain 計算が始まります。
f:id:kentaro-suto:20181030165119p:plain 最も確率が高い1〜3カテゴリに関するコメントと、上位10カテゴリの確率が表示されます。

結果

いくつか興味深い結果をご紹介します。なお、この記事で用いた写真およびデモページのサンプル写真は、私が個人的に撮影したものです。

画像 コメント
f:id:kentaro-suto:20181029143853p:plain バターナット・スクウォッシュを調べたら、確かに似ていました。ヒョウタンは学習データにありませんでした。
f:id:kentaro-suto:20181029160544p:plain どうやら本当はキジバトのようです。このように、知らないものに対しては、学習した中で似ているものを返します。
f:id:kentaro-suto:20181029144117p:plain 食べ物というところまでは合っています。
f:id:kentaro-suto:20181029144847p:plain 猫判別能力は概ね高いのですが、これは見破れなかったようです。
f:id:kentaro-suto:20181029144431p:plain メインの被写体について何も分からない場合、小さくても隅っこでも、知っているものに反応する場合があります。
f:id:kentaro-suto:20181030165004p:plain 右下の黒くて四角い何かに強く反応しました。比べると左上の生き物については、いまいち確証が無かったようです。
f:id:kentaro-suto:20181029144914p:plain 見えない何かに反応することも。生き物に関しては、それがいそうな背景でも判断しているようです。
f:id:kentaro-suto:20181029144722p:plain かたくなに犬と言い張ります。どうも縞模様で猫を判別しているようです。
f:id:kentaro-suto:20181030165034p:plain やっぱり。しかしこれだけ抽象化されたデザインに反応するのは珍しいです。
f:id:kentaro-suto:20181030165342p:plain 例えばシマウマは知っています。
f:id:kentaro-suto:20181029161711p:plain しかしこれはシマウマとは判定されません。
f:id:kentaro-suto:20181029144814p:plain 納得しかけましたが、よく見たら貯金箱じゃありませんでした。

まとめ

Tensorflow.jsの使い方と、それを用いたデモを紹介しました。 前処理があまり必要ない場合は、想像より簡単にモデルを動かせるようになります。 exaBaseの他の既存モデルのデモもおいおい追加して行きたいと思っています。

尚、エクサウィザーズは優秀なエンジニア、社会課題を一緒に解決してくれる魔法使い”ウィザーズ”を募集しています。ご興味を持たれた方はぜひご応募ください。 採用情報|株式会社エクサウィザーズ

ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!

写真に写っていないところを復元する

f:id:kentaro-suto:20181024193158p:plain

こんにちは。エクサウィザーズAIエンジニアの須藤です。

みなさんはハイキングの写真でしずちゃんばかり写して、まともに撮られなかったジャイアンに殴られかけたことは無いでしょうか。 そんなとき「万能プリンター」があったら便利ですね。もう撮ってしまった写真の、向きやズームを後から修正して、写ってなかったところを復元して再プリントできるというものです。 しかし持ち主であるドラえもんは、うちにもまだ来ていません。仕方がないのでAIの力でなんとかしましょう。

目的

写真の外側に写っているものを推測し、自然な形で合成します。

f:id:kentaro-suto:20181024190925p:plain:w200

物体の部分画像からその種類ないし位置を推測し、既存画像を本に全体を復元することが、原理的には可能なはずです。 その過程を直接にプログラムすることは現実的ではありません。 代わりに畳み込みニューラルネットワーク(CNN)に、かいつまんで学習させます。

学習モデル

敵対的生成ネットワーク(GAN)のアルゴリズムに従います。目的に書いたとおりのことを行う生成器と、画像の本物らしさを判別する判別器を、交互に学習します。 f:id:kentaro-suto:20181024153648j:plain

  • 生成器を以下のように構成します。
    • 画像を特徴量マップにする多層CNN
    • 特徴量マップを縦横二倍に広げる多層転置CNN
    • 特徴量マップから画像を生成する多層転置CNN
  • 事前学習を行います。
    • 真ん中だけを切り取った画像から元の画像を生成するように、生成器を学習
  • 画像が本物かを判定する判別器を構成します。
  • 以下の学習を交互に繰り返します。
    • 生成器を通した画像に対して0、本物の画像に対して1を返すように、判別器を学習
    • 生成した画像に対して判別器が1を返すように、生成器を学習

データセット

はじめ様々な物体を含む画像データセットで学習を行いましたが、外周部のピクセルの色を拡散させる以上のことをしてくれませんでした。 どんな写真にも対応させるには、学習時間またはモデルの性能が不足しているようです。

そこで顔画像データセットであるLabelled Faces in the Wildを用いて学習することにしました。 処理すべき情報量が減るため、少なくとも顔の部分に関しては精度の高い結果が期待できます。 一方で、使い道が想像つかなくなりましたが、モデルの検証が目的ということでご理解ください。

結果

入出力の例です。本来の全体画像も正解として記載します。 入力画像はエクサウィザーズのメンバー紹介ページから借りました。

入力 出力 正解 コメント
f:id:kentaro-suto:20181024173438p:plain f:id:kentaro-suto:20181023105104p:plain f:id:kentaro-suto:20181023105052p:plain ネクタイができかけています。
f:id:kentaro-suto:20181024173458p:plain f:id:kentaro-suto:20181023105619p:plain f:id:kentaro-suto:20181023105603p:plain なぜか法衣みたいに。
f:id:kentaro-suto:20181024173509p:plain f:id:kentaro-suto:20181023104451p:plain f:id:kentaro-suto:20181023104439p:plain 背景のボケがいい感じに。予測モデル的には失敗です。
f:id:kentaro-suto:20181024173538p:plain f:id:kentaro-suto:20181023104803p:plain f:id:kentaro-suto:20181023104751p:plain これまた作務衣のよう。
f:id:kentaro-suto:20181024173600p:plain f:id:kentaro-suto:20181023105944p:plain f:id:kentaro-suto:20181023105934p:plain ワイルドに。
f:id:kentaro-suto:20181024173618p:plain f:id:kentaro-suto:20181023110015p:plain f:id:kentaro-suto:20181023110003p:plain ムーディーに。
f:id:kentaro-suto:20181024173631p:plain f:id:kentaro-suto:20181023105919p:plain f:id:kentaro-suto:20181023105907p:plain 背景色がきっちり伸ばされているところに注目。
f:id:kentaro-suto:20181024173712p:plain f:id:kentaro-suto:20181023104826p:plain f:id:kentaro-suto:20181023104814p:plain 髪型提案モデルとして使えるかも。
f:id:kentaro-suto:20181024173730p:plain f:id:kentaro-suto:20181023104742p:plain f:id:kentaro-suto:20181023104734p:plain はみ出るのが少ないと変な結果になりにくいです。
f:id:kentaro-suto:20181024173903p:plain f:id:kentaro-suto:20181023104723p:plain f:id:kentaro-suto:20181023104712p:plain 髪の毛が大増量です。
f:id:kentaro-suto:20181024173924p:plain f:id:kentaro-suto:20181023105652p:plain f:id:kentaro-suto:20181023105633p:plain 無地の背景はなんとかして避けようとします。
f:id:kentaro-suto:20181024173949p:plain f:id:kentaro-suto:20181023105721p:plain f:id:kentaro-suto:20181023105706p:plain 首を隠すとなぜか太くなりがちです。
f:id:kentaro-suto:20181024174015p:plain f:id:kentaro-suto:20181023105551p:plain f:id:kentaro-suto:20181023105521p:plain 襟があるかないか決めかねたようです。
f:id:kentaro-suto:20181024174026p:plain f:id:kentaro-suto:20181023111924p:plain f:id:kentaro-suto:20181023111914p:plain 妙にごつくなりました。
f:id:kentaro-suto:20181024174041p:plain f:id:kentaro-suto:20181023105757p:plain f:id:kentaro-suto:20181023105735p:plain やたらきらびやかになりました。
f:id:kentaro-suto:20181024174114p:plain f:id:kentaro-suto:20181023105438p:plain f:id:kentaro-suto:20181023105428p:plain ビシッと黒スーツです。

髪の毛や顎など、画像中に無いものが付け足されています。背景や髪の毛の色もある程度反映されます。手がかりが全くない場合、髪型は丸刈り、服装は黒スーツになるようです。

デモ

ブラウザ上で動作するデモを用意しました。こちらのページですぐに試すことができます。 https://base.exawizards.com/view/modelDetail?id=45

操作説明

f:id:kentaro-suto:20181023102234p:plain ページ下部のリンクをタップしてください。モデルのダウンロードが始まります。12MBほどあるので、通信環境によってはお時間をいただきます。
f:id:kentaro-suto:20181023104008p:plain お手持ちの顔写真を選択してください。計算は全てブラウザ上で行われます。写真データが外部に送信されることはありません。
f:id:kentaro-suto:20181023102321p:plain 写真が読み込まれました。写真を直接ドラッグすることで位置を、スライダーではスケールを調節できます。
f:id:kentaro-suto:20181023102406p:plain 顔が白い枠いっぱいに表示される状態になったら、決定ボタンを押してください。計算が始まります。
f:id:kentaro-suto:20181023102421p:plain 枠内の内容だけから再生成された画像が表示されます。別のファイルを選択するか、位置を変える操作を行うと、再計算ができます。

まとめ

学習に基づき写真の範囲を拡張するAIモデルを作りました。

本手法はこんなときに役に立つ、かもしれません。

  • 不完全な写真からポートレートを作成
  • 構図の良くない写真を修正
  • 写真一枚から環境マップを生成
  • 4:3の映像を16:9のテレビで見る際の余白埋め

尚、エクサウィザーズは優秀なエンジニア、社会課題を一緒に解決してくれる魔法使い”ウィザーズ”を募集しています。ご興味を持たれた方はぜひご応募ください。 採用情報|株式会社エクサウィザーズ

ExaWizards Engineer Blogでは、定期的にAIなどの技術情報を発信していきます。Twitter (https://twitter.com/BlogExawizards) で更新情報を配信していきますので、ぜひフォローをよろしくお願いします!