エクサウィザーズ Engineer Blog

株式会社エクサウィザーズのエンジニアチームブログ

因果推論とグラフ理論

こんにちは。数理最適化ギルドでエンジニアをしている加藤です。

ある自社プロダクトの開発を通じて因果推論について勉強する機会がありました。因果推論は統計の分野ですが、その中で数理最適化の技術が使えることを知り、とても面白かったのでその内容をシェアしようと思います。具体的には組合せ最適化問題のひとつである最小カット問題が、因果推論のタスクの一部である識別可能性に利用できるという話をします。

前半は因果推論についての概説で特に予備知識は仮定していないです。後半は計算時間やネットワークフローなどのアルゴリズムを知っていると読みやすいと思います。

因果推論とは

因果推論の目的

統計的因果推論とは事象の間の因果効果を実験データや観測データから推定することを目的とした統計学の一分野です。単に因果推論といった場合は統計的因果推論を含むより広い概念を指すことがありますが、簡単のため以下では因果推論といえば統計的因果推論のことであるとします。

例えばある販促キャンペーンを実施したときに、本当に効果があったのかどうかなどを合理的に判断することなどに利用できます。この例のように物事の原因となったと想定されるものを処置(トリートメント)変数、その原因に影響を受けたと想定されるものを結果(成果・アウトカム)変数と言います。また処置が結果に与えた影響の大きさを因果効果と言います。

因果推論には主にRubin流の潜在的アウトカムを用いる流儀とPearl流の因果グラフを用いる流儀があります。本稿で説明するのは後者のPearl流因果推論です。

因果関係があるどうかを判断するというのは意外と難しく、その理由の一つは人間は相関関係を因果関係を誤認してしまうことにあります。

因果≠相関

相関・因果の誤謬の例としてしばしば引き合いに出されるのが次のグラフです[1]。

チョコレート消費量とノーベル賞受賞者数の関係

これは一流の医学雑誌に掲載された論文の分析で「一人当たりのチョコレートの年間消費量が多い国ほどノーベル賞受賞者が多い」という相関関係がわかります。論文の著者は「国民が1人あたり年間400g多くチョコレートを摂取すると、その国のノーベル賞受賞者の数が1人増える」と結論づけていますがこれは当然正しい推論ではなく、実際は「一人当たりGDP」という変数が前者と後者の変数どちらには影響を与えることによって生み出されている擬似的な相関です。疑似的な相関(あるいは疑似的な無相関)のことを交絡といい、交絡を作り出すの「一人当たりGDP」のような変数を交絡変数と呼びます。こういった擬似相関を排除したい場合はどうすれば良いのでしょうか?

もう少し単純な例で考えてみます。次の表をご覧ください。

治療結果の表

この表は何かしらの疾病・怪我などに対して治療した場合としなかった場合での生存・死亡数を表しています。治療をした場合もしなかった場合も生存率が50%で変わっていないので一見この治療には効果がないように見えます。しかし次の表を見てください。

男女別治療結果の表

この表をみると男性を治療した場合の生存率は約62%、しなかった場合は約57%なので治療によって生存率が上がっています。一方女性も治療した場合が約44%、しなかった場合が40%なのでこちらも生存率が上がっています。つまり一見効果がないように見えた治療でしたが、男女別にすると治療効果が見えてくるということがあります。

このように層別化などによって正しい因果効果が推定できるような変数の集合のことを調整変数と呼びます。また調整変数が観測できるとき因果効果は識別可能であると言います。つまり上の例では「チョコレート消費量→ノーベル賞受賞者数」の因果効果を正しく推定するための調整変数が「一人当たりGDP」なので、この因果効果は識別可能であるということになります。

調整変数選択の難しさ

調整変数によって層別化すれば正しい因果効果を推定できると述べましたが、肝心の調整変数はどうやって選択すれば良いのでしょうか?少なくともなんでもかんでも調整変数に含めれば良いわけではないということが次の例でわかります。次の表を見てください。

幼児が遊んだトランプの表

この表は幼児が遊んだトランプの分類です。数値的には先ほどの表と全く同じです。汚れありのグループをみると絵札より数札の方が赤札の割合が高いです。一方汚れなしのグループも絵札より数札の方が赤札の割合が高いです。つまり「汚れの有無にかかわらず、絵札より数札の方が赤札の割合が高い」と読めます。しかし絵札も数札も赤と黒は同数あるので「絵札より数札の方が赤札の割合が高い」という主張は当然正しくありません。したがってこの例では「汚れの有無」という属性で分けることで間違った結論を導いてしまいます。

上記の例のように層別化することでバイアスが発生してしまうことを選択バイアスと呼びます。*1

このように何を調整変数にすべきか・そもそも調整変数は存在するのか(識別可能どうか)を判断するのは難しい問題ですが、その問いへの回答として冒頭でも触れたPearlの因果グラフの理論があります。続く記事ではその内容を解説します。

因果グラフ

グラフとは

グラフとは物事のつながりを表現するための構造のひとつです。現実世界のさまざまなものをモデル化する力があり、因果推論以外の文脈でも頻繁に現れる対象です。

数学的にはノードとエッジと呼ばれる対象から構成され、ノードはグラフの頂点でエッジはふたつのノードをつなぐ辺を意味します。フォーマルな定義は以下です。ついでに後々使う概念の定義もしておきます。

定義(グラフ)G = (V, E)をグラフという。Vの要素をノードあるいは頂点という。EVの要素のペアからなる集合で、その要素をエッジあるいは辺という。エッジに向きが指定されている場合そのエッジは向きがついているという。またUからVへの向きがついたエッジはU \to Vと書く。またUをエッジの始点、Vをエッジの終点という。

定義(無向グラフ、有向グラフ、単純グラフ)全てのエッジに向きが付いているグラフを有向グラフという。逆に全てのエッジに向きがついていないグラフを無向グラフという。全てのノードのペアに対して高々一つしか辺がなく、全てのノードが自分自身へのエッジを持たないグラフを単純グラフという。以後グラフは全て単純であるとする。

定義(パス、ループ)ノードの列N_0, N_1, \cdots, N_pであって全てのiに対してN_iN_{i + 1}の間にエッジが存在するものをパスという。またそれらが全てN_iからN_{i + 1}へ向きがついている場合、このパスを有向パスという。またN_0およびN_pをそれぞれパスの始点および終点という。始点と終点が一致しているパスをループと呼び、それが有向パスである場合に有効ループと呼ぶ。

定義(親、祖先、子、子孫)有向グラフのノードNに対してNに入ってくるエッジの始点であるノードをNの親、Nから出ていくノードの終点であるノードを子、Nを終点とする有向パスの始点を祖先、Nを終点とする有向パスの終点を子孫という。

定義(DAG)単純な有向グラフであって有向ループを持たないものをDAG(有効非巡回グラフ)という。

水色の丸がノード 青い矢印がエッジこの例は特にDAGになっている

因果推論においてはノードに事象あるいは変数を対応させ、エッジは因果関係を表します。このように因果関係をグラフで表現したものを因果グラフと呼びます。また因果グラフ上でノードUからノードVにエッジが存在するときUVの原因であるということがあります。

交絡の例

因果グラフを使うと交絡変数は次のように表せます。ただしT, Y, Uがそれぞれ処置変数、結果変数、交絡変数を表します。以下で矢印の実線は因果関係を、波線は交絡を表していると考えてください。

交絡変数によって処置変数と結果変数が関係し合っている

交絡変数Uで調整することにより交絡を取り除く様子を次のように表現します。

交絡変数Uを調整することにより交絡を取り除く

逆に調整変数にすべきでない変数による交絡が生じる場合もあります。以下のように調整変数も結果変数も原因であるような変数で調整すると選択バイアスによる交絡が生まれます。

Sを調整変数に選択することによって交絡が生まれる

このように因果関係をグラフによって表現し、交絡変数や選択バイアスを視覚的に表現できます。もう少し複雑な例を考えてみましょう。次のグラフをご覧ください。

複雑な因果グラフ

上記の例ではまずVが交絡変数になっています。またUTVの交絡因子になっており、VYの原因なのでT \gets U \to V \to Yというパスを介してTYは関係しあっています。同様にT \gets V \gets W \to YというパスもありこれもTYの交絡を作り出しています。このような処置変数Tに入ってくるエッジからスタートする結果変数Yへのパスをバックドアパスと呼びます。

3つの交絡

これらの交絡を取り除くためにはどのような調整変数を選択すれば良いでしょうか。実はVを調整変数にすると上記3つの交絡は全て取り除けます。なぜなら上記3つの交絡をもたらすバックドアパスの経路上にVが存在して交絡をブロックしているためです。バックドアパスとパスのブロックのフォーマルな定義は以下になります。

定義(バックドアパス)有向グラフG = (V, E)の頂点T, Y \in Vに対してTからYへの有向とは限らないパスであってTへ入ってくるエッジからスタートするものをバックドアパスという。

定義(パスのブロック)有向グラフG = (V, E)とその頂点の部分集合S \subset VおよびG上の有向とは限らないパス\gamma が次のいずれかの条件を満たすとき S\gamma をブロックするという。

1. \gamma上の合流点であって、自分を含む全ての子孫がSに含まれないものが存在する。 2. \gamma上の非合流点であってSに含まれるものが存在する。

ただしパス上の端点以外のノードであってパス上で隣り合う二点からエッジが入ってくるようなノードを合流点と呼んでいる。

しかし一方で別の問題が発生します。それはVを調整変数にすることでU \to V \gets Wに選択バイアスが生まれT \gets U \to V \gets W \to Yという新たな交絡が生じてしまうのです。この新たな交絡を取り除くためには{U, V}などを調整変数として選ぶ必要があります。

Vで調整すると新たな交絡が生まれるのでUも選ぶ

上のような特殊な例に限らず交絡を取り除くためにはどんな調整変数を選べば良いのでしょうか。その答えが次のd分離性です。

定義(d分離)有向グラフG = (V, E)の頂点T, Y \in Vと頂点の部分集合S \subset Vが次の性質を満たすときS\{T, Y\}をd分離(有効分離)するという。TからYへの全ての有向とは限らないパス\gammaに対してS\gammaをブロックする。

少し面倒ですが上述の特殊な例において\{U, V\}\{T, Y\}をd分離していることが確認できます。

有効分離があるのに無向分離はないのかと思われるかも知れませんが一応こちらもあります。後で使うのでここで定義しておきます。

定義(u分離)有向グラフG = (V, E)の頂点T, Y \in Vと頂点の部分集合S \subset Vが次の性質を満たすときS\{T, Y\}をu分離(無向分離)するという。TからYへの全ての(無向)パス\gammaに対して\gamma上にSの要素が少なくとも1つ存在する。

バックドア基準

さらに一歩進んでバックドア基準についても述べておきます。交絡を取り除いて因果効果のみを知りたいときに有用な概念です。

定義(バックドア基準)因果グラフG = (V, E)とそのふたつの頂点T, Y \in Vに対して次の条件を満たすノードの集合S \subset Vをバックドア基準を満たすという。

  1. すべてのノードZ \in Sに対してTからZへの有向パスが存在しない。
  2. \{T, Y\} \cup Sの祖先グラフ上でSTYをd分離する。

ただし有向グラフG = (V, E)のノードの集合V_0 \subset Vに対してその祖先グラフとは、ノードV^{an}V_0のすべての要素およびその祖先とし、エッジE^{an}Gのエッジであってその始点・終点がV^{an}の要素であるものとするグラフのことである。

やや定義が複雑ですがバックドア基準を満たす変数集合を調整変数として選べば介入効果などが表現できることが知られています[2]。

ところで処置変数の親を全て調整変数にすれば必ずバックドア基準を満たすことも知られています。しかしこれは必ずしも良い選択ではありません。調整変数は層別化や回帰といった用途に用いられる変数なので、できるだけ少ない数の調整変数が知りたいという欲求があります。

つまり我々の目標は次のようになります

バックドア基準を満たす最小の変数集合を見つけたい!

定義よりバックドア基準はd分離性と密接に関わっていることがわかると思いますので次のような目標設定でも良いです。

d分離性を満たす最小の変数集合を見つけたい!

次の章では上の目標を実現するアルゴリズムについて解説します。

調整変数選択のアルゴリズム

d分離性の言い換え

ここではd分離性を満たす最小の変数の集合をどうやって見つけたらよいかを、計算量にも触れつつお話しようと思います。

ひとつの方法は変数の部分集合を全て列挙してその集合がd分離性を満たすかどうかを判定することです。この方法はすぐに思いつきますし、実際にこのように実装されているライブラリもあるようです。実用上はあまり問題がないかも知れませんが、変数の集合が多くなるとこの方法は非常に遅くなります。実際変数がN個だとすると計算時間がO(2 ^ N)もかかるので規模が大きくなると途端に遅くなるでしょう。

アルゴリズムに詳しい人ならd分離性の定義を見たときに「ソース(処置変数)からシンク(結果変数)への全てのパスが何かしらの意味でブロックされているという定義だから、最小カットのようなアルゴリズムが使えそうだ」と思いつくかも知れません。これから述べるようにその直感は正しいのですが、最小カットを応用するためにはd分離性の定義を言い換える必要があります。それが次のモラルグラフを使った言い換えです。

定義(モラルグラフ)G = (V, E)を有向非巡回グラフとする。GのモラルグラフG^mとはVをノードとし、次で定義するエッジの集合をもつ無向グラフことである。Vの2頂点v_1およびv_2に対して次のいずれかが成り立つとき\{v_1, v_2\}G^mのエッジであるとする。

  1. v_1 \to v_2あるいはv_2 \to v_1Gのエッジである。
  2. G内で、v_1およびv_2を共通の親としてもつノードが存在する。

モラルグラフの作り方

ここでモラルグラフの構築にかかる計算量を調べてみると、全てのノードのペアに対して上記2つの条件のうちいずれかを満たすかどうかをチェックするだけなのでO(N^2)程度です。モラルグラフに対して次の定理が成り立つため非常に有用な概念になっています[3]。

定理([3], Theorem 3)G = (V, E)を有向非巡回グラフとする。頂点\{T, Y\}に対してS\{T, Y\}の祖先の集合の部分集合であるとする。このときS\{T, Y\}をd分離することは\{T, Y\}の祖先グラフのモラルグラフH = G_{An(T, Y)}^m内でS\{T, Y\}をu分離することと同値である。

この言い換えによりd分離集合を求めるためには

  1. 処置変数と結果変数の祖先グラフのモラルグラフを構築する。
  2. モラルグラフ内でu分離する頂点集合を求める

というステップを踏めば良いことがわかりました。1.に関しては解説しましたので次は2.について解説していきます。

最小カット問題

モラルグラフ内でu分離する頂点集合を求めるというタスクは最小カット問題で解くことができます。最小カット問題とは次のような問題です。ここだけの用語ですが有向グラフGの2頂点u, vとエッジの集合Fに対してGからFを取り除いたグラフにおいてuからvへの有向パスが存在しないとき、Fu, vをカットするということにします。

有向グラフG = (V, E)がある。また全てのエッジにはコストという値が定められているとする。始点Sと終点Gが定められている。始点と終点をカットする「エッジのコストの合計」の最小数はいくらか?

カットの例

モラルグラフのu分離性に使えそうな設定の問題ですね!でも後一歩必要です。最小カット問題はエッジを除去することでパスをブロックする問題ですが、u分離性は頂点によってパスをブロックするからです。このギャップを埋めるためには次のようなグラフを考えれば良いです。

モラルグラフHのノードuに対してu_+およびu_-というノードとu_+ \to u_-というエッジを作りコスト1を与えます。またuvをつなぐエッジがHに存在する場合、u_- \to v_+およびv_- \to u_+というエッジを作りコスト\inftyを与えます。以上で作ったノードとエッジからなるグラフを\tilde{H}と書きます。

最小カットに帰着する

H\tilde{H}に対して次のことが言えます。

H内でTYをu分離するノードの集合の最小数と、\tilde{H}内でT_-Y_+の間の最小カット数は同じである。また次のように一方の問題の解から他方の問題の解への変換ができる。

  1. HにおいてST, Yをu分離するノードの集合だとすると、\tilde{H}においてF = \{u_+ \to u_-\ | u \in S\}T_-, Y_+をカットする。
  2. \tilde{H}においてFT_-, Y_+をカットするエッジの集合だとすると、HにおいてS = \{u | u_+ \to u_- \in F\}T, Yをu分離する。

これでモラルグラフのu分離集合の求め方がわかりました!これとd分離性のモラルグラフでの言い換えと合わせると結局d分離性を満たす最小の変数集合を高速に計算できることになります!

また使ったアルゴリズムも最小カットなので適当なソルバーを使えば高速に計算できます。\tilde{H}のエッジおよびノードはそれぞれ高々O(N^2)およびO(N)のオーダーなのでDinicのアルゴリズムなどで最悪O(N^4)程度の計算時間になります。愚直な方法が指数時間であることを考えると劇的な改善です!

終わりに

今回は因果推論におけるアルゴリズムの活躍についてお話ししました。筆者の乱文をここまで読んで下さった方には大変感謝いたします。以下のリストは調整変数の選択とアルゴリズムに関連する内容として興味深いですが筆者の勉強と気力が不足していて書けませんでした。またTechBlogを書く機会があってやる気に満ち溢れていたら勉強して書こうと思います。

  • フロントドア基準・操作変数法
  • optimalな調整変数の選択
  • mixed graphでの識別可能性
  • IDアルゴリズム

エクサウィザーズでは一緒に働く人を募集しています。中途、新卒両方採用していますので、興味のある方は是非ご応募ください!

参考文献

  1. Messerli, F. H. (2012). Chocolate Consumption, Cognitive Function, and Nobel Laureates. The New England Journal of Medicine, 367, 1562-1564.
  2. 宮川雅巳. (2004). 統計的因果推論ー回帰分析の新しい枠組みー. 朝倉書店.
  3. Tian, J., & Paz, A., & Pearl, J. (1998). Finding Minimal D-separators. Tech. Rep. R-254, Univ. of California, Los Angeles.
  4. Acid, S., & De Campos, L. M. (2013). An algorithm for finding minimum d-separating sets in belief networks. arXiv preprint arXiv:1302.3549.

*1:選択バイアスはより広い意味で用いられることがありますがこの記事では上記の意味でのみ用います。

Zero-shot Learning入門

こんにちは。エクサウィザーズで画像ギルドに所属し、機械学習エンジニアをしている小島です。今年の3月からこちらにジョインいたしました。

この記事では、弊チームで取り組んいるテーマ「Zero-shot Learning」について、歴史的な背景を振り返りつつ、簡単な実装を紹介します。今研究でホットな研究テーマの一つである「クロスモーダルモデル」を身近に感じていただければ幸いです。

Zero-shot Learningとは

「Zero-shot Learningとは何か」というのは、実は曖昧なテーマです。「これがZero-shotだ」という定義が論文によって異なるためです。わかりやすい理解の仕方としては、Many-Shot Learning、One/Few-shot Learningから天下り的に考えていくことでしょう。

画像系の機械学習の問題は、大きく分けて、タスクの軸データ数の軸の2軸で考えられます。

タスクの軸については、分類問題(Classification)、物体検出(Object Detection)、セマンティック/インスタンスセグメンテーション(Instance/Semantic Segmentation)を最も基本的な3つのタスクとして挙げています。タスクの軸はこれ以外にもいろいろあるので、この3つが正解というわけではありません。

本題はデータ数の軸で、ここでは「タスク固有の訓練データ」を意味します。例えば、「犬か猫か」の画像分類モデルを訓練したい場合、犬の画像を数百枚、猫の画像を数百枚持ってきて、ImageNetで訓練されたモデルをfine-tuningするのが一般的なやり方でしょう。この場合、タスク固有の訓練データは「数百枚×クラス数(犬or猫=2)」必要となり、データ数の軸では「Many-shot」となります。

では、Few/One-shotとは何でしょうか。One-shotとは文字通り、タスク固有の訓練データが1つ、Fewなら数枚か~少し多い程度でしょうか。Few/One-Shotの典型例は顔認証です。例えば、1人あたり数百枚の顔写真を入れ、顔認証をMany-shotの問題として訓練・運用するのは現実的でありません。データが取れないという問題もありますし、認証対象に1人追加されるとクラス数が変わるため、モデル全体を訓練し直さないといけないからです。この図は[1]の論文からのものです。

One-shot Learningでは、「クラスが同一かどうかを学習している」点に注意してください。Many-shotでは、犬や猫といった特定のクラスに属するかを学習していました。ここで大事なのが、学習やタスクの設定を工夫すれば、タスク固有の訓練データは減らせるという点です。ここの改良はZero-shotでも続けられています。

Zero-shot黎明期

ディープラーニングのZero-shotの論文に行く前に、Zero-shotの発端となりうる論文を2つ紹介します。

Zero-shot Learning

1つ目は、Palatucci et al.(2009)によって書かれた論文です[2]。AlexNet[3]が2012年に発表されたので、これはディープラーニングのブームが来る前の論文です。論文のタイトルにも「Zero-shot Learning」とついていますが、本文中にZero-shot Learningについて問題提起しています(※翻訳はDeepLで作成しています)。

Given a semantic encoding of a large set of concept classes, can we build a classifier to recognize classes that were omitted from the training set?
大規模な概念クラス集合の意味論的符号化が与えられたとき、学習集合から漏れてしまったクラスを認識する分類器を構築できるか?

また、アブストラクトでは、Zero-shot Learningを以下のように定義しています。

We consider the problem of zero-shot learning, where the goal is to learn a classifier f : X \to Y that must predict novel values of Y that were omitted from the training set.
学習セットから漏れてしまった新しいYの値を予測しなければならない分類器f : X \to Yを学習する、ゼロショット学習の問題を考える。

訓練データにはないクラスを認識することがポイントで、この観点は現在のZero-Shot Learningでも根底にある思想の1つです。

なお、この論文の著者にはジェフリー・ヒントンとトム・ミッチェルがいます。ジェフリー・ヒントンはディープラーニング界のゴッドファーザーの1人として非常に有名な研究者です。トム・ミッチェルの「機械学習の定義」は非常によく引用されるので、どこかで見た方もいらっしゃるのではないでしょうか。彼の著書[4]からです

A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P, if its performance at tasks in T, as measured by P, improves with experience E.
コンピュータ・プログラムが、あるタスクのクラスTと性能指標Pに関して、経験Eから学習すると言われるのは、Tのタスクにおける性能が、Pによって測定されるように、経験Eとともに向上する場合である。

Zero-data Learning

2つ目はLarochelle et al.(2008)らによって書かれた、「Zero-data Learning of New Tasks」という論文[5]です。「Zero-data Learning」と「Zero-shot Learning」紛らわしいですね。図はZero-data Learningの論文からです。

「Zero-data Learning」は以下のように定義されています。

We introduce the problem of zero-data learning, where a model must generalize to classes or tasks for which no training data are available and only a description of the classes or tasks are provided.
ゼロデータ学習とは、学習データがなく、クラスやタスクの説明のみが提供されているクラスやタスクに対してモデルを汎化しなければならない問題である。

Zero-data Learningの特徴として「学習データがない」という点が明確に記載されています。また、クラスやタスクの説明のみが提供されているということも興味深い。実はこのZero-data Learningの説明こそ、後で紹介するCLIPのやっていることと一致するという点も、頭の片隅に入れておきたい点です。

論文の図を見ると、左側は手書き数字の「1,2,3」のみ訓練データとして与え、「A, B」を推論するケースで、これは完全に学習データがないケースです。真ん中はマルチタスク学習で、欠損値補完のイメージでしょうか。いずれにしても「学習データがないクラスや組み合わせに対して推論する」という点がポイントです。

冒頭で、「現在のZero-shot LearningはMany-shotやFew/One-shot Learningから天下り的に考えると、タスク固有の訓練データがない学習方法だ」ということを述べました。まさにこれが、Zero-data Learningの考え方と重なるわけです。

すなわち、現在のディープラーニングでZero-shot Learningと呼ばれているものは、厳密にいえば、Palatucci el al. (2009)のZero-shot Learningと、Larochelle et al. (2008)のZero-data Learningの少なくとも2つの側面があると言えます。現状では、各論文の著者が「これがZero-shotだ」と言っているので、源流の論文を明確に意識する機会はほとんどありません。しかし、10年以上前の論文の思想が、現在も普遍的に生きていることには驚かされます。

Larochelle et al. (2008)の著者には、同じくディープラーニングのゴッドファーザーと呼ばれるヨシュア・ベンジオが含まれています。AI界の伝説の存在が10年以上前に揃ってこのトピックに取り組んだことから考えると、Zero-shotはそれだけに魅力的な問題なのでしょう。

時代はテキストと画像のクロスモーダルへ

クロスモーダルのアプローチ

Zero-shotを「テキスト」と「画像」のクロスモーダルの問題として捉えたのが、Socher et al.(2013)[6]です。これまでのZero-shot LearningもZero-data Learningも、教師データがあるのを前提にしていました。本論文の大きな違いは、あらかじめ用意された大規模なコーパスから、ラベルのテキストの埋め込み量と外れ値検出を組み合わせて、未知の画像のクラスを推定している点です。

イメージとしては、「truck」と「cat」が未知のクラスだとします。猫の画像が与えられた場合、まず既知のクラスに対して「未知である」という外れ値検出します。そして、テキストの埋め込み量を使い、近傍の「猫」である推定します。ポイントは、大規模なコーパスを使っているため、クラスのラベルとしては、トラックも猫も既知であることです。

なお、この論文の著者には、同じく機械学習界の重鎮であるアンドリュー・ンがいます。私も含め、彼のCourseraの講義にはお世話になった方が多いのではないでしょうか。

DeViSE

Socher et al.(2013)[6]をもう少しこなれた形にしたのが、DeViSE[7]です。DeViSEでは、外れ値検出を使わず、画像と単語の類似度を直接学習していきます。後で紹介するCLIPにかなり近いモデルです。

クロスモーダルのアプローチを使ったのはこの論文が初ではありませんが、要旨には「ラベル付き画像だけでは学習データの取得に限界があるから、テキストデータを活用して精度を高めよう」という意図が書かれています。私がZero-shotを調べたときに「なぜ学習データがないタスクに対して、text-supervisedのアプローチがデファクトスタンダードとして扱われているのか」という点が理解できませんでした。おそらくDeViSEのあたりから、このような方向性が確立されてきたのかな考えられます。

DeViSEの論文は、Googleのチームによって書かれたものですが、著者にはジェフ・ディーンがいます。Zero-shotはどれだけ界隈のレジェンドを惹き付けるのでしょうか。

Visual N-Grams

Visual N-Grams[8]は、現在のCLIPのように、インターネットからダウンロードした膨大な画像-テキストの組み合わせに注目した論文です。

自然言語処理では古くからN-Gramモデルが使われていますが、この考え方を画像に転用したものです。例えば、港にあるクレーンを1つとっても、人間は「navy yard」や「サンディエゴの港」のような多面的な認識をします。Visual N-GramsでもZero-shotへの言及はありますが、Visual N-Gramsの論点が現在のZero-shotを俯瞰するには有用だと思います。要旨からです。

Real-world image recognition systems need to recognize tens of thousands of classes that constitute a plethora of visual concepts. The traditional approach of annotating thousands of images per class for training is infeasible in such a scenario, prompting the use of webly supervised data. This paper explores the training of image-recognition systems on large numbers of images and associated user comments, without using manually labeled images.
実世界の画像認識システムでは、膨大な数の視覚的概念を構成する何万ものクラスを認識する必要があります。このようなシナリオでは、クラスごとに数千枚の画像にアノテーションを付けて学習する従来のアプローチは現実的ではなく、Webから学習したデータの利用が必要である。本論文では、人手でラベル付けした画像を用いずに、大量の画像と関連するユーザコメントから画像認識システムを学習する方法を検討する。

ポイントは2点あります。

  • 実世界の画像認識を考えると、(人間のような)何万ものクラスを多面的に認識したい。従来のMany-shotのアプローチはこれに向かない
  • アノテーションは、人がつけるのではなく、Webに付随しているデータ(コメント)から直接利用するべき

つまり、従来のMany-shotの問題設定としての限界、アノテーションを人間がしたくないことの2点が、この背景となる思想です。これを知ることで、現在のZero-shotの流れが理解しやすくなります。

Contrastive Learning

最初は次元削減から

今まで、Zero-shot Learning、クロスモーダルに関する論文をいくつか紹介してきましたが、CLIPを理解するためには「Contrastive Learning」について知っておく必要があります。そもそもCLとはどのようにして生まれたのでしょうか?

Contrastive Learningの原案は、これもディープラーニングがブームになる前から生まれていました。Hadsell et al.(2006)[9]は、「spring system」と呼ばれる次元削減のための学習システムを構築しました。これはもともと、高次元のデータを低次元の多様体にマッピングする次元削減の問題として提案されました。図はこちらの論文からです。

黒い丸は類似サンプル、白は似ていないサンプルです。このように似ているサンプル同士を近づけて、似ていないサンプル同士を離すという、バネのようなシステムであったから「spring system」と呼ばれました。spring systemという呼び方は現在ほとんどされませんが、考え方自体は現在のContrastive Learningと同じです。

なお、この著者にはヤン・ルカンがいます。Zero-shot Learningを追っているだけで、「ディープラーニングのゴッドファーザー」と呼ばれているチューリング賞受賞の3人:ヤン・ルカン、ヨシュア・ベンジオ、ジェフリー・ヒントンの論文をコンプリートできます。

SimCLR

ディープラーニングにおけるContrastive Learningの大きなブレイクスルーになったのがSimCLR[10]です。

2020年の論文で、spring boxからいきなり時代を飛び越えましたが、「サンプル間の類似度を学習し、類似したものを近づけようとする」という大枠は変わっていません。\mathcal{T}はData Augmentationで、同一のデータ\mathbb{x}をData Augmentationして、異なるサンプル\mathbb{\tilde{x_i}, \tilde{x_j}}を得ます。これを同じニューラルネットワークに通して、類似度を学習します。SimCLRはクロスモーダルではなく、画像で完結するためこのようなAugmentationを入れています。

SimCLRは、ディープラーニングの問題としてシンプルな枠組みで実現し、教師なし学習でも教師あり学習に迫る精度を達成したことから、大きな注目を集めました。SimCLRのような枠組みは自己教師あり学習(Self-supervised learning)とも呼ばれます。教師ありとはいうものの、自分自身を教師として使うため、従来の教師あり学習のように人間が作ったラベルは必要ありません。これは、先程紹介したクロスモーダルなモデルと発想が似ています。

SimCLRの著者には、またヨシュア・ベンジオがいます。先程紹介したZero-data Learningから14年越しです。これだけ長い期間、インパクトのある研究を出し続けられるのは本当に驚かされます。

なお、この記事を読んでいる方には「Contrastive Learningと、One-shot Learningの文脈で語られることの多いMetric Learningは何が違うのか?」という疑問も抱いた方もいるかもしれません。私もこの疑問を持っていたのですが、最近ではContrastive LearningにMetric Learningの要素を加えた研究[11]もあり、両者の境界が曖昧になっています。厳密には違いがあるのかもしれませんが、方言の違いぐらいの認識で良いと思います。

CLIP誕生

クロスモーダルなZero-shot Learning

これまで長い期間をかけて、

  • Zero-shot Learning
  • 画像と文章のクロスモーダル
  • Contrastive Learning

の3つを紹介しました。これらはすべて「CLIP[12]」を説明するためのパーツです。いよいよCLIPを見ていきましょう。

CLIPはOpenAIによって2021年に発表された激強論文です。Zero-shotかつクロスモーダルで、従来のImageNetで訓練済みのMany-shotのモデルよりも頑強性が高いことが大きく注目されています。以下の図はCLIPの論文からですが、目にした方も多いのではないでしょうか。

CLIPもContrastive Learningをしています。SimCLRでは、1つの画像にData Augmentationを適用して2種類の画像を作り、その類似度を最大化していました。CLIPでは、これが画像とテキストのクロスモーダルなので、1枚の画像とそれに対応するキャプションの類似度を最大化するようなContrastive Learningになります。画像のキャプションとはどういうものかというと、次のようなものです。これはOpen Images Dataset V6[13]のサンプルで、元画像の作者は[14](CC BY 2.0)です。

「Fractal Cauliflower」というのがキャプションです。インターネット上の画像(例:Flicker、Instagram)には、キャプションが付属していることが多いので、それを活用します。この例では、キャプションを通じて「幾何学的」と「カリフラワー」の2つの概念を学習できます。

ただし、テキストと画像の特徴量は同じネットワークでは計算できないため、それぞれ別のモデル(例:Transformer、Vision Transformer)を使います。例えば、バッチサイズをN、特徴量の次元をDとすると、画像とテキストネットワークの出力特徴量は、それぞれ(N, D)次元で表されます。これらの行列積を取り出し、(N, N)というコサイン類似度の行列を作り、対角成分を正例、それ以外を負例としてContrastive LearningするのがCLIPです。

この図も論文からの引用です。ポイントは、コサイン類似度の実装をL2-Normalizeで行われていることで、ベクトルa, bに対するコサイン(類似度)というのは、

\cos(a, b)=\frac{a\cdot b}{|a| |b|} = \frac{a}{|a|}\cdot\frac{b}{|b|}

で表されます。これは高校数学の教科書にも載っている公式ですが、右辺はL2-Normalizeそのものです。行列に拡張したのが上図のコードです。高校数学でおなじみの公式が、最先端の研究のキー技術になっているのは、学校で教えて欲しいぐらいです(機械学習で高校数学がいるよ、と言われるのはこういう理由です)。

ロジットの計算の部分で、さらっと温度付きの計算して、さらに学習パラメーターとするという面白いことをやっているのですが、それはおいておきましょう。

プロンプトエンジニアリング

CLIPでは、訓練時に画像の持っているキャプションを活用しましたが、推論時はどう考えているのでしょうか。推論時の画像はキャプションを持っていない場合も多いです(例:カメラから撮った画像)。CLIPではプロンプト(prompt)という、自然言語処理由来の独特な概念が出てきます。この図はOpenAIのブログ記事からです[15]

飛行機を推定するには、「a photo of a airplane / bird / bear / ...」といろいろなテキストがありますよね。このようにターゲットとなるテキストを複数提示して、最も類似度の高いテキストを選択するのが、CLIPの推論方法です。ここで「a photo of {label}」というテンプレートが、プロンプトです。プロンプトという単語は聞き慣れませんが、Windowsに出てくる「コマンドプロンプト」を連想すると馴染み深いのではないでしょうか。プロンプトを調整して精度を上げていくのがプロンプトエンジニアリングです。機械学習でおなじみの「特徴量エンジニアリング」の仲間として捉えると理解しやすいかもしれません。

CLIPでのプロンプトエンジニアリングは、この他に「A photo of a big {label}」「A photo of a small {label}」というように、複数のプロンプトを組み合わせてアンサンブルします。これが精度の改善に大きく寄与します。CLIPの論文からです。

プロンプトエンジニアリングとアンサンブルにより、ImageNetのZero-shot精度を5%、モデル計算量で4倍改善できたとのことです。ここはこの記事本来の目的ではないので、細かくは書きませんが、CLIPには「プロンプト」というクロスモーダル特有の概念が出てくるよ、ということを覚えておいてください。やっていることは単なるテンプレート構文です。

CLIPとZero-data Learnig、Visual N-Grams

今回は「CLIP」の紹介が主な目的ですが、CLIPより前の古い論文もいくつか掲載しました。その理由は、10年以上も前の古い論文で培われた思想が、CLIPでも生きていることを味わってほしかったのと、CLIPがどういった理念から作られたのかを知ってほしかったからです。私個人が純粋に興味あったからというのもあります。

さて、最初にZero-data Learningの論文を紹介しました。CLIPはZero-shot Learningでありながら、実はZero-data Learningの要素もかなり含んでいます。Zero-data Learningを再度引用してみましょう。

We introduce the problem of zero-data learning, where a model must generalize to classes or tasks for which no training data are available and only a description of the classes or tasks are provided.
ゼロデータ学習とは、学習データがなく、クラスやタスクの説明のみが提供されているクラスやタスクに対してモデルを汎化しなければならない問題である。

単に学習データがないというと誤解を招きますが、CLIPにはMany-shotのようなタスク固有の学習データがないのは事実です。ユーザーがFine-tuningしなくても、学習済みモデルに、飛行機の画像とプロンプトを与えれば、飛行機だと推論できます。「クラスとタスクの説明のみ提供されている」という点はまさにプロンプトです。

つまり、CLIPは「訓練データにはないクラスを認識する」というZero-shot Learningの文脈にありながら、Zero-data Learningをまるで伏線回収のように取り込んでいるのです。13年越しにこのような回収して、大きなブレイクスルーを導き出したのはかなり痺れるものがあります。

Visual N-Gramsは、CLIPの論文内でも比較対象として言及されるほど、強く意識していた論文です。Visual N-Gramsのポイントを再掲します。

  • 実世界の画像認識を考えると、(人間のような)何万ものクラスを多面的に認識したい。従来のMany-shotのアプローチはこれに向かない
  • アノテーションは、人がつけるのではなく、Webに付随しているデータ(コメント)から直接利用するべき

CLIPは、アノテーションの事情もふまえつつ、人間のような多面的かつ何万クラスの認識をしたいという考え方を受け継いでいるわけです。多面的な認識はプロンプトのラベルの単語を変えることで可能にしていますし、アノテーションはクロスモーダルなContrastive Learningによって解決しています。Visual N-Gramsの要旨で書かれたような、実世界の画像認識がCLIPの登場でより現実的になったということが言えるでしょう。

CLIPの欠点

万能のように書いたCLIPですが、実は欠点もあります。それはモデルの訓練に莫大な量のGPUとデータが必要なことです。CLIPの論文からです。

The largest ResNet model, RN50x64, took 18 days to train on 592 V100 GPUs while the largest Vision Transformer took 12 days on 256 V100 GPUs.
最大のResNetモデルであるRN50x64の学習には592台のV100 GPUで18日、最大のVision Transformerは256台のV100 GPUで12日かかりました。

ResNet50×64でV100が592×18=29.2年分も必要です。ResNet50×4に減らしても1回訓練するだけで1.8年分のGPUが必要です。CLIPのような強いモデルをスクラッチから訓練できるのが、ビックテックのような限られた存在になってしまい、技術を寡占されてしまわないかという懸念があります(訓練に億円単位がかかるモデルを自由に利用できるように公開する、仏のような心の持ち主が多数いるのを期待したいです)。

また、必要なデータ数も莫大です。後で説明する後続研究のSLIP[16]によれば、CLIPは4億枚の画像-テキストのデータを使用しており、これを1500万枚に減らしたところViT-BでImageNetのZero-shot精度が37.6%になったと報告されています[17]。300万枚では、同じモデルで17.1%まで落ちてしまいました。Zero-shotの精度を教師ありImageNet並に上げたければ、億単位の膨大なデータを使った訓練が必要という身も蓋もない話になります。

SLIPの論文によれば、ViT-B/16でCLIPを再実装するために、1500万枚に限定しても64台のV100で22.3時間(59.4日分)だったことが記されています。V100を2ヶ月使ってImageNetが40%行けばいい、というのはなかなか考えさせられるものがあります。

CLIPの先にあるもの

激強の論文として紹介したCLIPですが、訓練済みCLIPを使った研究が最近(2022年)とても活発です。CLIPのその後の論文をいくつか紹介していきましょう。

CLIPの改良

まずはダイレクトにCLIPの学習をより効率的にする方法です。1つ目はUC BerkeleyとFacebook AIのチームが2021年に発表したSLIP[16]で、図は論文からです。これはCLIPの画像側に自己教師あり学習を組み合わせ、先程紹介したSimCLRを画像側に追加したものです。CLIPよりも精度が高くなります。

2つ目はICLR 2022で発表されたDeCLIP[18]です。(1)モダリティ内、(2)複数のビューを通じたモダリティ間、(3)類似ペアから最近傍、それぞれの自己教師あり学習します。これにより、CLIPよりも7.1倍少ないデータ数で同じ精度を達成できています。

生成モデルへの応用

訓練済みCLIPと生成モデルは相性が良く、AnyFace[19]ではテキストを使って、ソースの顔写真から自在に合成できる手法です。

DALL・E 2[20][21]は、「馬に乗る宇宙飛行士」のようにテキストを入力すると画像を生成するモデルで、非常に精緻な画像が生成されることで話題になりました。この実装には訓練済みCLIPと拡散モデルが使われています。画像はOpenAIのサイトからです。

生成モデル以外にもCLIPの応用は多数あるのですが、これを紹介するときりがなくなるので割愛します。CLIPの活用がまさに研究の最先端で行われています。

SLIP(CLIPの後続研究)を動かしてみる

記事の締めくくりとして、Zero-shotのモデルを実際に動かしてみましょう。CLIPの推論は有用な記事が多数あるので、ここでは後続研究のSLIP[16]を動かします。CLIPのように使い勝手の良いAPIは整備されていませんが、公式のコード[17]をいじることで比較的簡単に実装できます。

ダウンロード

まずはSLIPの公式リポジトリからソースコードをダウンロードしてきます。PyTorchとtimmが別途必要です。

同じリポジトリから、訓練済みモデル(チェックポイント)をダウンロードします。いくつかモデルがありますが、ここでは「ViT-Base / SLIP / 100Epochs」(0-shot 45.0%)をダウンロードします。「Weights」のurlからダウンロードできます。2GB近いファイルなので、容量に気をつけてください。

先程クローンしたリポジトリの直下に「weights」フォルダを作り、そこに保存します。

チェックポイントの軽量化

ダウンロードしたチェックポイントは、推論に不要なパラメーターも含んでいるので、軽量化します。これにより、1.9GB→651MBに軽量化されます。元のチェックポイントは削除して構いません。

import torch
from collections import OrderedDict
from tokenizer import SimpleTokenizer
import models
from torchvision import transforms
from PIL import Image
import numpy as np

def strip_checkpoint(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    old_args = ckpt['args']
    print("=> creating model: {}".format(old_args.model))
    model = getattr(models, old_args.model)(rand_embed=False,
        ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim)
    model.cpu()
    model.load_state_dict(state_dict, strict=True)
    print("=> loaded resume checkpoint (epoch {})".format(ckpt['epoch']))

    torch.save(model, "weights/"+old_args.model)

strip_checkpoint("weights/slip_base_100ep.pt")

推論関数の作成

SLIPにはCLIPのように簡単に利用できる推論APIがないので、それっぽいものを作ります。公式コードをアレンジしたものです。

def load(model_name, device):
    model = torch.load("weights/"+model_name).to(device)

    image_preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        lambda x: x.convert('RGB'),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    ])
    tokenizer = SimpleTokenizer()

    return model, image_preprocess, tokenizer

def inference(img_path, templates, classes, device="cuda:0"):
    model, image_process, tokenizer = load(
        "SLIP_VITB16", device)

    image = image_process(Image.open(img_path)).unsqueeze(0).to(device)

    text_features = []
    with torch.no_grad():
        # 言語特徴量
        for label in classes:
            texts = [t.format(label) for t in templates]
            texts = tokenizer(texts).to(device)
            texts = texts.view(-1, 77).contiguous()
            class_embeddings = model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0) # (n_classes, 512)

        # 画像特徴量
        image_features = model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # コサイン類似度→ロジット→確率
        logits_per_image = model.logit_scale.exp() * image_features @ text_features.t()
        pred_prob = torch.softmax(logits_per_image, dim=-1).cpu().numpy()
        return pred_prob

img_pathは画像のパス、templatesはプロンプト、classesはクラス名のテキストを表します。プロンプトとクラス名を固定で複数の画像を推論したい場合は、言語特徴量やモデルをキャッシュさせるとよいでしょう。

一番いい推論を頼む

SLIPを使って画像を推論してみましょう。使用するサンプルは「そんな装備で大丈夫か?」で話題になったエルシャダイ[22]です。フリー素材として公開されている[23]ので、ありがたく利用します(画像クレジットはこちらで付与したものです)。

ではイーノックが何を着ているのかSLIPで判定してみましょう。ちゃんと装備をつけていると判定できるでしょうか?

# プロンプト
templates = {
    "a picture of a {}.",
    "a {} in a video game.",
    "a {} in an animation.",
    "a {} in a movie."
}
classes = ["man wearing uniform", "man wearing swimsuits", "man wearing armor", "naked man"]
result = inference("e3_luciferpv_1080p_free/02/e3_luciferpv_1080068.jpg", templates, classes)
print(classes)
print(np.round(result, 3))

プロンプトは4個与えてアンサンブルしています。クラスの候補としては、

  • man wearing uniform(制服を着ている)
  • man wearing swimsuits(水着を着ている)
  • man wearing armor(鎧を着ている)
  • naked man(裸)

の4つを与えました。結果は以下の通りです。

['man wearing uniform', 'man wearing swimsuits', 'man wearing armor', 'naked man']
[[0.121 0.001 0.877 0.   ]]

87.7%の確率で鎧を着ているとなりました。「大丈夫か?」と心配されるような装備でも無事に鎧と認識できました。

Zero-shotの面白いところは、追加の訓練もモデルを変更しなくても、問題の枠組みを変えられるところです。先程は「何を着ているか?」の分類でしたが、今度は「どんな行動しているか?」の分類にしてみます。

イーノックがキックしているシーンです。この写真だけでキックと認識できるでしょうか? 選択肢を次のように与えて同様に推論してみます。

  • man shooting(射撃している)
  • man kicking(キックしている)
  • man sleeping(寝ている)
  • man eating(食べている)
['man shooting', 'man kicking', 'man sleeping', 'man eating']
[[0.16  0.818 0.018 0.004]]

81.8%で蹴りと判定できました。shootingは攻撃モーションから、近いと判断されたのかもしれません、行動認識と組み合わせると面白そうですね。

まとめ

本記事では、クロスモーダルでZero-shot LearningのブレイクスルーであるCLIPが、過去の論文からどういう思想のもとで生まれたのかを俯瞰し、Zero-shotのモデルとしてSLIPの使い方を紹介しました。ディープラーニングのゴッドファーザー達が過去に生み出したアイディアが脈々と受け継がれ、CLIPによって一気に開花し、まさに今Zero-shotかつクロスモーダルなモデルが、社会実装へと向かいつつあります。このダイナミズムをぜひ一緒に体感しましょう。

引用

  1. Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. "Siamese neural networks for one-shot image recognition." ICML deep learning workshop. Vol. 2. 2015.
  2. Palatucci, Mark, et al. "Zero-shot learning with semantic output codes." Advances in neural information processing systems 22 (2009).
  3. Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet classification with deep convolutional neural networks." Advances in neural information processing systems 25 (2012).
  4. Tom Mitchell. "Machine Learning." McGraw Hill. 1997.
  5. Larochelle, Hugo, Dumitru Erhan, and Yoshua Bengio. "Zero-data learning of new tasks." AAAI. Vol. 1. No. 2. 2008.
  6. Socher, Richard, et al. "Zero-shot learning through cross-modal transfer." Advances in neural information processing systems 26 (2013).
  7. Frome, Andrea, et al. "Devise: A deep visual-semantic embedding model." Advances in neural information processing systems 26 (2013).
  8. Li, Ang, et al. "Learning visual n-grams from web data." Proceedings of the IEEE International Conference on Computer Vision. 2017.
  9. Hadsell, Raia, Sumit Chopra, and Yann LeCun. "Dimensionality reduction by learning an invariant mapping." 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06). Vol. 2. IEEE, 2006.
  10. Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.
  11. Chen, Shuo, et al. "Large-margin contrastive learning with distance polarization regularizer." International Conference on Machine Learning. PMLR, 2021.
  12. Radford, Alec, et al. "Learning transferable visual models from natural language supervision." International Conference on Machine Learning. PMLR, 2021.
  13. https://storage.googleapis.com/openimages/web/index.html
  14. https://www.flickr.com/people/tristanf/
  15. https://openai.com/blog/clip/
  16. Mu, Norman, et al. "SLIP: Self-supervision meets Language-Image Pre-training." arXiv preprint arXiv:2112.12750 (2021).
  17. https://github.com/facebookresearch/SLIP
  18. Li, Yangguang, et al. "Supervision exists everywhere: A data efficient contrastive language-image pre-training paradigm." arXiv preprint arXiv:2110.05208 (2021).
  19. Sun, Jianxin, et al. "AnyFace: Free-style Text-to-Face Synthesis and Manipulation." arXiv preprint arXiv:2203.15334 (2022).
  20. https://openai.com/dall-e-2/
  21. Ramesh, Aditya, et al. "Hierarchical Text-Conditional Image Generation with CLIP Latents." arXiv preprint arXiv:2204.06125 (2022).
  22. http://elshaddai.jp/elshaddai_crim/index.html
  23. http://elshaddai.jp/elshaddai_crim/freedeta.html

AI王 〜クイズAI日本一決定戦〜 第2回コンペティション 振り返り

こんにちは、エクサウィザーズNLPギルド所属の神戸です。

本記事では、3月に終了しました「AI王 〜クイズAI日本一決定戦〜 第2回コンペティション」の振り返りとなります。 NLPギルドのチームで参加しまして、3位入賞という結果になりました!

要点

  • コンペのタスクやルール、配布データについてなどの概要を説明しています
  • コンペで3位に入賞した投稿システムの説明資料を公開しましたので、質問応答システムにご興味のある方はご参考ください
  • コンペ中に複数のWebサイトからクローリング&スクレイピングして独自に作成したデータセットも公開しましたので、ご活用ください

コンペ概要

  • 日本の(日本語を対象とした)質問応答研究を促進させることを目的としています。クイズ問題を題材とした質問応答データセットを用いてクイズに解答するAIを開発するコンペとなっています。
  • 去年に第1回目コンペが開催され、今年が第2回目コンペとなります
  • 東北大学、理化学研究所、NTTの共同プロジェクトで運営をされています。

コンペルール

ルールは大まかに以下のようなルールがありました。Wikipediaのナレッジソースなども含めた形式でDockerファイルを提出するようになっておりこの辺りはよくみるコンペとは異なっているところかと思います。

  • 情報源,モデル等含めて圧縮済みで30[GB]以内、実行環境を含むDockerイメージを提出 + 実行時間 6h以内
    • 評価値 = 1,000問の正答率 (配布学習データ 22,335問)
  • 暫定評価:⽂字列の完全⼀致(Exact Match)で確認
  • 最終評価:⼈⼿で表記の揺れなども考慮して確認
  • Wikipediaを含め,⼀般公開されている, もしくは公開できるデータのみ利⽤可能
  • 外部リソース(インターネット検索など)は利⽤禁止

また、第2回コンペでは,第1回のコンペと異なり選択肢を排除し,あらゆる解答がありえるという,より通常のクイズ大会に近い設定となっていました。つまり,問題として与えられるのはクイズの問題文のみとなっていました。その問題文のみから解答となる文字列を解答として返すシステムを構築する必要がありより汎用的&難易度の高い設定になっていました。

引用: 加藤拓真, 宮脇峻平, 第二回AI王最終報告会 - DPR ベースラインによる オープンドメイン質問応答の取り組み (2022) - Speaker Deck

配布データ

公式からは以下のデータが配布されています。

また、学習用データのサンプルのフォーマットは以下のようになっています。

{
  "qid": "AIO02-0001",
  "competition": "第2回AI王",
  "timestamp": "2021/01/29",
  "section": "開発データ問題",
  "number": 1,
  "original_question": "映画『ウエスト・サイド物語』に登場する2つの少年グループといえば、シャーク団と何団?",
  "original_answer": "ジェット団",
  "original_additional_info": "",
  "question": "映画『ウエスト・サイド物語』に登場する2つの少年グループといえば、シャーク団と何団?",
  "answers": [
    "ジェット団"
  ]
}

投稿システム説明と独自作成したデータセット

私たちのチームの投稿システム説明資料(最終報告会で発表したものと同じ)と独自に作成したデータセットはアップロードしていますので、こちらをご参照ください。

投稿システム説明資料 drive.google.com

独自に作成したデータセット(データセットについては説明資料の外部データをご参照ください) drive.google.com

作成したデータセットの統計は以下のようになっています。データ数としては76721となっており質問応答のデータセットとしては比較的多いデータ量になっています。 また、Elasticsearchで正例のpassageを付与できたものの数は60443となっていました。

外部データセットの統計

投稿システムの簡潔なまとめは以下となります。

  • Retriever-Reader構成のモデルを使用
    • Retriever: 東北大BERT-Baseモデル(cl-tohoku/bert-base-japanese-whole-word-masking)
    • Reader: 東北大BERT-Large(cl-tohoku/bert-large-Japanese)
  • 正例のpassageをシャッフルして学習
  • 学習したRetrieverのhard negativeなサンプルを学習データに追加して再学習
  • 外部データの活用
    • クイズの杜、みんはや、Mr Tydi、Quiz Works、語壺、Erin、Wiktionary QA、5TQなど外部データを利用
    • 外部データはクローリング&スクレイピングして作成
      • サイトごとにクローリング&スクレイピングするコードを実装して作成しました

また、リーダーボードスコアの遷移の以下のようになっています。

リーダーボードスコアの遷移

結果

最終結果は3位でした!(自動評価2位、人手評価3位) また次回コンペが開催されるとのことなので、次回コンペではより良い結果を目指したいと思います!

最終順位結果

引用: AI王 〜クイズAI日本一決定戦〜 第2回コンペティション 公式サイト

コンペを振り返って

コンペに参加してみて以下のような多くの学びがありました。コンペを通じて得られた知見を今後の業務に活かしていきたいと思います。 また、引き続き今回のようなコンペに参加し外部への情報発信にも取り組んで参りたいと思います。

  • Retriever-Readerの構成の質問応答タスクで、ナレッジソースから関連のあるドキュメントもとってくるところもやるのは初めてだったので学びになった
  • Web上に質問応答タスクに使えそうな外部データが結構あることを知ることができたのも学びだった
  • データ量はやはり重要だと再認識、反面いかに少量のデータで精度を出せるかも重要
    • 1,2位のチームは外部データの利用なく精度が出ていたので、参考にしたい
  • エラー分析より、BERTの文脈理解だけでは解くことができない問題を理解できた
    • エラー分析については投稿システム説明資料をご参照ください

エクサウィザーズでは一緒に働く人を募集しています。中途、新卒両方採用していますので、興味のある方は是非ご応募ください!

hrmos.co

event.exawizards.com

RecSys2021学会参加報告記事

 こんにちは。エクサウィザーズで構造化データギルドに所属し、機械学習エンジニアかつエンジニアリングマネージャーをしている小野です。本記事では2021年に推薦システムの国際会議にヴァーチャル出席しましたので(本当はアムステルダムに行きたかったです。)、一部の内容を共有させていただきます。

概要

f:id:oimokihujin:20220407045453p:plain  今回は、オランダのアムステルダムで2021年9月27日から10月1日までバーチャル&オフライン開催された推薦システムの国際会議であるRecSys2021*1に参加したので、内容を記事にさせていただきました。皆様がご存じのように、推薦システムはすでに私たちの生活に切っても切り離せない存在です。Amazonなどのオンラインストアで商品を眺めていると、隣に出てくるオススメ商品をクリックしてしまうことなどはないでしょうか?それらは皆様の行動履歴や商品分類など様々な情報を駆使し、適切なユーザーに適切なアイテムを推薦することによって、ユーザーの意思決定(この場合は購買意欲)を促進します。

 RecSys 2021の開催で16回目の開催となります。RecSysは4日間の日程で多くの本会議(下図は本会議の日程)と多くのワークショップなどから構成されており、非常にボリュームがある会議となっております。また、本会議とワークショップが同時に開催されており、全ての会議に同時に出席することはできません。しかし、コロナの影響前からRecSysはYoutube*2で会議内容を発信しており、当日に参加できない場合でも、後日会議内容を確認することができます。RecSys2021も同様に開催後半年を経て会議内容がYoutubeに公開されているので、当日参加できなかった人も現在は確認することができます。

f:id:oimokihujin:20220407045612p:plain

ワークショップから見る会議の全体感について

 多くの推薦システムの根幹となる技術は機械学習です。昨今、機械学習でも話題となっていることは、推薦システムでも同様に話題となっております。大きな話題の一つとして、機械学習の社会的責任性が挙げられます。この社会的責任性は、公平性、透明性、責任性となります。それぞれを簡単に説明すると、公平性は、誰が使っても不公平ではない推薦結果を出すことができるかに焦点が当てられています。例えば、クーポン配布推薦システムを構築する際、たまたまデータセットの大部分が男性だった場合、男性に偏ったクーポン配布が実施されてしまう可能性があります。透明性はその名が示す通り、推薦システムが推薦した理由が明確にわかる必要があります。例えば、推薦システムがある男性にコーヒーを推薦した場合、その理由を明示する必要があります。責任性は、推薦システムの責任性を示します。例えば、健康器具を推薦するシステムがあるとします。この健康器具を推薦しするシステムが腰痛持ちの人に腰痛を悪化させるような推薦をしてしまい、結果としてユーザーが腰痛を悪化させてしまった場合、推薦システムがどうしてそのような推薦をしたのかを明らかにしなければなりません。

 RecSysでは、推薦システムの社会的責任性を議論するためのワークショップを開催しており、FAccTRec: Workshop on Responsible Recommendationで中心的に議論されています。このワークショップは5年目であり、比較的新しいワークショップであることがわかります。オーガナイザーには産総研の神嶌先生*3の名前もあります。

各ワークショップについて

 各ワークショップでは、発表タイトルが4〜8つ程あり、それぞれが15分ほどの発表時間を持ち発表する形式でした。ここではいくつかのワークショップの概要を説明し、言及したいワークショップについては少し詳しく説明したいと思います。

f:id:oimokihujin:20220407045700p:plain

  1. CARS: Workshop on Context-Aware Recommender Systems
    1. 次世代コンテキストアウェア推薦システムを議論するためのワークショップです。ここでいうコンテキストアウェアとは、ユーザーやアイテムの付帯情報を指します。具体的には、ユーザー特徴量(性別・住所・年齢・など)やアイテムの特徴量(カテゴリ・値段・色など)を指し、これらの情報に加えて、あるユーザーがあるアイテムを購入した情報(購入履歴や閲覧履歴など)を用いて推薦システムを構築します。
  2. ComplexRec: Workshop on Recommendation in Complex Environments
  3. FAccTRec: Workshop on Responsible Recommendation
  4. FashionxRecSys: Workshop on Recommender Systems in Fashion and Retail
  5. GReS: Workshop on Graph Neural Networks for Recommendation and Search
    1. グラフベースのモデルを用いて推薦システムを議論するためのワークショップです。グラフベースのモデルとは、知識グラフなどを用いた推薦システムを指し、アイテム-アイテム間、アイテム-ユーザー間、ユーザーーユーザー間などに存在する関係を明示的に取り扱うことができる推薦システムです。グラフベースのモデルを使うことによって、より高次の情報(商品A→購入者B→商品C→購入者Dなど)や関係を取り込むことができ、精度の面で優れているとされています。
  6. INRA: Workshop on News Recommendation and Analytics
  7. IntRS: Joint Workshop on Interfaces and Human Decision Making for Recommender Systems
    1. 推薦システムの精度やモデルを議論する場ではなく、デザインやインターフェースを議論するワークショップです。例えば、推薦システムで20位までのオススメ商品の結果を表示する際に、2ページで表示する際に1ページ目と2ページ目でバイアスが含まれてしまうなどの弊害があります。また、ユーザーが推薦結果に納得感を得るためのインターフェースが紹介されていました。
  8. KaRS: Workshop on Knowledge-aware and Conversational Recommender Systems
  9. MORS: Workshop on Multi-Objective Recommender Systems
  10. OHARS: Workshop on Online Misinformation- and Harm-Aware Recommender Systems
    1. 間違った情報や人を傷つける情報を察知する推薦システムを議論するワークショップです。SNSなどの情報発信ツールが普及する現在では、間違った、または他人を傷つけるようなコンテンツなどの普及速度も非常に早くなっている。このようなコンテンツをいち早く検出し、他人に推薦しないようなシステムの構築を目指します。特に、現在のcovid-19が流行っている世の中では正しい情報をなるべく早く広めるためにもこのような推薦システムが重要となります。
  11. ORSUM: Workshop on Online Recommender Systems and User Modeling
    1. オンライン推薦システムを議論するワークショップ。ニュースの記事やコメントやユーザーのフィードバックなど、時間が進めば進むほどユーザーに対するコンテキストや情報が蓄積していきます。それらの新鮮な情報をいち早く取り入れ、より良い推薦システムを構築することで、「あの時」は欲しかったのに、「今はもういらない」とならないように「その時」欲しいものを「その時」推薦するシステムの構築を目指します。
  12. PERSPECTIVES: Workshop on Perspectives on the Evaluation of Recommender Systems
    1. 推薦システムの評価に注目したワークショップ。
  13. PodRecs: Workshop on Podcast Recommendations
  14. RecSys Challenge Workshop
  15. RecSys in HR: Workshop on Recommender Systems for Human Resources
    1. 人事部門に関する推薦システムに関するワークショップです。PWCによると、国際企業のHR機能の40%以上がAIアプリケーションを使用しているらしく、特に、優秀人材のスクリーニングなど簡易的に評価するために用いられることが多いそうです。そのような場面では個人に関わる非常にセンシティブなデータを使用するため、推薦システムの社会的責任性が重要となります。
  16. RecTour: Workshop on Recommenders in Tourism
  17. SimuRec: Workshop on Synthetic Data and Simulation Methods for Recommender Systems Research
  18. XMRec: Workshop on Cross-Market Recommendation

参加しての所感

 本記事ではそれぞれの内容の詳細に触れるのではなく、RecSys2021のワークショップに焦点を当てて参加報告を記しました。会議内容の詳細はYoutubeで確認することができるので気になる点を重点的に確認していただくと理解が深まるかと思います。RecSys2022はシアトルで開催されることが決まっています。コロナの状況にもよりますが、次回は現地参加できればと思います。

Load testing with Artillery.io

Introduction

This article is going to be about Artillery, a popular load and smoke testing framework.

Recently I used Artillery to evaluate the performance of some of our production services. I'd like to present some of the scenarios I encountered, and ways to solve them. So if you're new to load testing, I hope this article can serve as a helpful introduction to Artillery.

Now let's get started!

Regarding the code samples

Note that in the below samples, everything will be installed into a local folder we create. So you can follow along and run all of these samples without needing to install anything on your machine globally. So there's no worry about side-effects or changes to the configuration on your system, you can simply delete the folder when you are done!

The only prerequisite is to install Node.

JSONPlaceholder (a simple test server)

In these samples, I'm going to be using a publicly-available test REST API service known as JSONPlaceholder as the server. The public version is available at https://jsonplaceholder.typicode.com/, but instead we're actually going to run this same code locally -- because Artillery is designed to put heavy load on the server, we do not want to cause problems for this free and useful service!

Creating and running tests

Installation

Create a local directory that we'll use to install the dependencies and run our tests

mkdir load-testing
cd load-testing

Install Artillery (and also module csv-parse which we'll need later)

npm install --save artillery
npm install --save csv-parse

Install JSONPlaceholder

npm install --save jsonplaceholder

(Note: you might get some warnings here about your Node version being too new, but you can ignore those. I used Node 15 without problems)

Run JSONPlaceholder server

node ./node_modules/jsonplaceholder/index.js

Running the first test sample

Now that our server is running, let's get our first test code up and running!

# load-testing.yml

config:
  target: "http://localhost:3000"
  phases:
    - duration: 60
      arrivalRate: 10
      name: "Run queries"

scenarios:
  - name: "Run queries"
    flow:
      - get:
          url: "/todos/1"

Let's see what we've got here:

  • We set the location of the server with target: http://localhost:3000
  • In the phases: section we configure to for 60 seconds with 10 simulated users
  • In the scenario "Run queries" we make a GET request to one of the endpoints (this loops until the time is up)

To run it:

./node_modules/artillery/bin/run run load-testing.yml

Reading test cases from a CSV file (payload files)

The above is well and good so far, but we're just requesting the same data from the same resource repeatedly. For most systems this will allow every request to be served from cached code and data, so it isn't a very good simulation of real-world usage. Therefore we'd like to vary the resources and parameters to provide more realistic testing, but it would be a bit unwieldy to hard code each value into the YAML file. This is where "payload" files come in -- we store these parameters in a CSV file, so we can easily create and change test cases without needing to modify the code.

Let's add the CSV file and the related code:

# queries.csv

resource,queryparam1,queryparam2,queryparam3
posts,_start=20,_end=30,
posts,views_gte=10,views_lte=20,
posts,_sort=views,_order=asc,
posts,_page=7,_limit=20,
posts,title=json-server,author=typicode,
comments,name_like=alias,,
posts,title_like=est,,
posts,q=internet,,
users,_limit=25,,
users,_sort=firstName,_order=desc,
users,age_gte=40,,
users,q=Sachin,,
# load-testing.yml

config:
  target: "http://localhost:3000"
  payload:
    path: "queries.csv" # path is relative to the location of the test script
    skipHeader: true
    fields:
      - resource
      - queryparam1
      - queryparam2
      - queryparam3
  phases:
    - duration: 60
      arrivalRate: 10
      name: "Run queries"

scenarios:
  - name: "Run queries"
    flow:
      - get:
          url: "/{{ resource }}?{{ queryparam1 }}&{{ queryparam2 }}&{{ queryparam3 }}"

Now we have the parameters in the CSV. In the payload: section we define the location of the file and variable names for each field, then in Run queries we use these variable names. The nice thing is that Artillery will advance to the next CSV row each time automatically for us!

Creating an initial test data set

With the test server we've been using the data is just static JSON, so it's easy to make every test run start out with a consistent dataset. When testing real services however, you may need to use an API to populate the initial data. Fortunately, it is possible to do this in Artillery without needing additional external tools -- we can use a "processor" (custom Javascript plugin) and put this into the before block (initialization code which runs before the test cases).

// utils.js

const fs = require("fs")
const parse = require('csv-parse')

function loadCsvIntoJson(context, events, done) {
    
    fs.readFile(context.vars['csvFilePath'], function (err, fileData) {
        parse(fileData, {columns: false, trim: true}, function(err, rows) {
            // CSV data is in an array of arrays passed to this callback as `rows`
            context.vars['csvRows'] = rows
            context.vars['row'] = 1

            done()
        })
    })
}

function getNextRow(context, events, done) {
    let row = context.vars['row']

    context.vars['userId'] = context.vars['csvRows'][row][0]
    context.vars['id'] = context.vars['csvRows'][row][1]
    context.vars['title'] = context.vars['csvRows'][row][2]
    context.vars['completed'] = context.vars['csvRows'][row][2]

    row++
    context.vars['row'] = row

    done()
}

function hasMoreRows(context, next) {
    return next(context.vars['row'] < context.vars['csvRows'].length)
}
# load-testing.yml
config:
  target: "http://localhost:3000"
  processor: "./utils.js"
  variables:
    csvFilePath: "todos.csv" # Path is relative to the location of the test script
  payload:
    path: "queries.csv"
    skipHeader: true
    fields:
      - resource
      - queryparam1
      - queryparam2
      - queryparam3
  phases:
    - duration: 60
      arrivalRate: 10
      name: "Run queries"

before:
  flow:
    - log: "Adding Todos..."
    - function: "loadCsvIntoJson"
    - loop:
      - function: "getNextRow"
      - log: Inserting Todo (id={{ id }})
      - post:
          url: "/todos"
          json:
            userId: "{{ userId }}"
            id: "{{ id }}"
            title: "{{ title }}"
            completed: "{{ completed }}" 
      whileTrue: "hasMoreRows"


scenarios:
  - name: "Run queries"
    flow:
      - get:
          url: "/{{ resource }}?{{ queryparam1 }}&{{ queryparam2 }}&{{ queryparam3 }}"

Using .env (dotenv) for configuration

Until now these examples have simply hard-coded many values, but in a real-world test automation setup we probably want to separate configuration from code (many teams use .env as a standard place to store secrets) For our setup, .env was the way to go, but Artillery doesn't support this itself. Fortunately there is a tool called dotenv-cli which can run any arbitrary executable with the variables from .env loaded into its environment. You can install this by running

npm install --save dotenv-cli

For example, we might put the location of the server into our .env file:

# .env
ARTILLERY_TARGET=http://localhost:3000

Then we can load this from the environment in the yaml file:

# load-testing.yml
config:
  target: "{{ $processEnvironment.ARTILLERY_TARGET }}"
  ...

Finally, run with dotenv-cli to use the .env values in the tests:

./node_modules/dotenv-cli/cli.js ./node_modules/artillery/bin/run run load-testing.yml

Interpreting the Output

After the test run completes, you will get some information like this:

All VUs finished. Total time: 1 minute, 7 seconds

--------------------------------
Summary report @ 17:59:44(+0900)
--------------------------------

http.codes.200: ................................................................ 600
http.request_rate: ............................................................. 10/sec
http.requests: ................................................................. 600
http.response_time:
  min: ......................................................................... 13
  max: ......................................................................... 202
  median: ...................................................................... 104.6
  p95: ......................................................................... 147
  p99: ......................................................................... 179.5
http.responses: ................................................................ 600
vusers.completed: .............................................................. 600
vusers.created: ................................................................ 600
vusers.created_by_name.Run queries: ............................................ 600
vusers.failed: ................................................................. 0
vusers.session_length:
  min: ......................................................................... 15.6
  max: ......................................................................... 339.2
  median: ...................................................................... 111.1
  p95: ......................................................................... 156
  p99: ......................................................................... 228.2

Most of these are pretty self-explanatory, but the meaning of "p95" and "p99" might not be immediately obvious. From the documentation:

Request latency is in milliseconds, and p95 and p99 values are the 95th and 99th percentile values (a request latency p99 value of 500ms means that 99 out of 100 requests took 500ms or less to complete).

You may also see lines like:

errors.ETIMEDOUT: .............................................................. 9412
errors.ESOCKETTIMEDOUT: ........................................................ 30
errors.ECONNREFUSED: ........................................................... 16550

These are socket level errors where the client couldn't connect to the server. As you increase the number of users and requests, you'll eventually reach a limit where the service cannot process all of the incoming requests.

Authorization - when the api requires an access token

In our case, our API server requires an authentication token. You can add this to the HTTP headers for requests (where access_token is the token returned by your authentication function):

- post:
    url: "/path/to/resource"
    headers:
      authorization: Bearer {{ access_token }}

Other Resources

JSONPlaceholder (the free server we used) is based on a framework called JSON Server, which is an extremely powerful tool that allows you to create a mock REST server from any arbitrary JSON in just a few minutes! It can very useful for development and testing.

Conclusion

That's it for this article! I hope you found it useful and I encourage you to check out the Artillery Docs if you are interested to learn more!

Kaggle「chaii - Hindi and Tamil Question Answering」コンペで2位入賞したお話 & 解法解説

こんにちは、エクサウィザーズで自然言語処理ギルドに所属している神戸です。(ギルド制についてはこちら

今回、AI/機械学習を用いたデータ分析技術の国際的なコンペティションプラットフォームKaggle上で2021年8月 ~ 2021年11月まで開催されていたchaii - Hindi and Tamil Question Answeringコンペ(略: chaiiコンペ)に参加し、私を含む "tkm kh" というチームで943チーム中2位に入賞&金メダルを獲得することが出来ました。

f:id:kambehmw:20211207115920p:plain
Private Leaderboardの結果

同時に今回の金メダルの獲得で私(kambehmw)はKaggle Competitions Masterに、チームを組んでくださったtkm2261さんはKaggle Competitions Grandmaster(GM)に昇格しています。tkm2261さんは、さすがGMという実力の方でコンペ中色々なことを学ばさせていただきました。チームを組んでいただきありがとうございました!

この記事では、今回のコンペでの私たちのチームの解法について紹介させていただきたいと思います。 また、KaggleのDiscussionに既にチームの解法については共有されておりますが、こちらも必要に応じてご参照していただければと思います。 https://www.kaggle.com/c/chaii-hindi-and-tamil-question-answering/discussion/287917

解法の要点

chaiiコンペで精度に大きく寄与した工夫は以下の3点になります。コンペのタスク概要を説明した後、以下の3点について、それぞれ順に詳細に説明いたします。

  1. 外部データの利用
  2. アンサンブル(XLM-R, Rembert, InfoXLM, MuRIL)
  3. デーヴァナーガリー数字 (०१२३४५६७८९)の後処理

コンペのタスク概要

今回のコンペのタスクはいわゆるオープンドメイン質問応答タスク(Open-Domain Question Answering)であり、以下の入出力のペアが与えられているデータセットになっていました。 質問文とコンテキスト文は対応するペアのものが与えられている設定で、コンテキスト文を参照することで質問文に対応する解答を予測することが求められていました。

入力

  • 質問文(Question)
  • コンテキスト文(Context)

出力

  • 質問に対するコンテキスト文の解答範囲(Answer Span)

f:id:kambehmw:20211207171110p:plain
質問応答タスク イメージ図

また、データセットの言語は、ヒンディー語とタミル語となっており以下のような難しさがありました。

  • 英語のようなリソースが多い言語に比べて利用できる外部データ、学習済みモデルが少ない
  • 読むことができない言語のため、EDAやエラー分析が大変
    • Google翻訳で適宜翻訳して、EDAやエラー分析を実施した

コンテキスト文は、ヒンディー語 or タミル語のWikipedia記事であり、そのため質問もWikipediaに書かれるような一般的な知識を問うようなものになっていました。

コンペのスコア評価はword-level Jaccard scoreによって計算されました。 具体的な実装はコンペの評価ページより以下になっています。

def jaccard(str1, str2): 
    a = set(str1.lower().split()) 
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

補足:オープンドメイン質問応答タスク

ちなみに、オープンドメイン質問応答タスクは以下の3つの枠組みのいずれかで解くことが最近主流となっています。

f:id:kambehmw:20211207123440p:plain

引用: https://lilianweng.github.io/lil-log/2020/10/29/open-domain-question-answering.html

今回のコンペでは解答が書かれているコンテキスト文は正解のものがあらかじめ与えられている設定でしたが、解答に関連したドキュメントも検索する必要がある場合は図のRetriverに相当したコンポーネントも実装する必要があります。(今回コンペでは、左下のReader部分のみ実装すれば良かった)

Retrieverは質問に解答するのに必要な情報をExternal Knowledge(例: Wikipedia)から抽出します。External KnowledgeがWikipediaの場合、RetrieverはBM25やDPR[1]といった手法を使用して質問に関連したドキュメントを抽出するのが最近の論文で行われることが多いです。

また、解答範囲の位置を予測するReaderではなく、質問文&コンテキスト文を入力して解答をテキスト生成的に予測するRetriever-Generator[2]の手法や、質問文のみを入力に直接解答を予測してしまうGenerator[3]の手法なども提案されています。Generatorの手法としてはT5などのパラメータを非常に多く持った大規模言語モデルが使用されており、これらのモデルにおいてはExternal Knowledgeを参照せずとも解答に必要な情報をある程度のレベルまで記憶していることが報告されています。

外部データの利用

外部データの利用として、私たちのチームは以下の2つの外部データを利用しました。

  • MLQA[4]
  • TyDi QA[5]に含まれる他のインドの言語を入れる(ベンガル語とテルグ語)

MLQAには下の画像のように7つの言語のデータがありますが、このうちヒンディー語(hi)のデータを一緒に学習することで精度向上しました。

f:id:kambehmw:20211207152838p:plain

また、TyDi QAは英語以外に10種類の言語のデータがあり、このうちのベンガル語とテルグ語のデータを一緒に学習することで精度向上しました。 特にTyDi QAと学習することで、Public Leaderboardのスコアは0.787 -> 0.799に改善しました。

推測にはなりますが、TyDi QAで精度改善した理由を私たちは以下のように考えています。

  • 同じインドで使用されている言語のため類似性がある
  • TyDi QAのデータと質問が重複している

f:id:kambehmw:20211207153627p:plain

アンサンブル(XLM-R, Rembert, InfoXLM, MuRIL)

chaiiコンペの前にあったCommonLitコンペの上位解法から、NLPコンペでアンサンブルは重要だと考えていました。 以下が私たちのチームが使用したモデルになります。モデルは全てlargeサイズを使用しました。

また、アンサンブルに関連した情報は以下です。

  • AlexKay/xlm-roberta-large-qa-multilingual-finedtuned-ruのモデルはロシア語の質問応答データセットでfine-tuningされたモデルなのですが、アンサンブルに寄与しました。
  • モデルを上から順にアンサンブルしていくと、次のように次第にスコアが改善しました。(0.799 -> 0.816 -> 0.821 -> 0.827 -> 0.829)
  • モデルごとにトークンの分割が異なるので、charレベルで予測の出力を合わせてsoftmaxで値を正規化するという操作を入れました
  • シード値を変えてモデルを学習したランダムシートアンサンブルもしています

デーヴァナーガリー数字 (०१२३४५६७८९)の後処理

後処理として、選択された解答の文字列がデーヴァナーガリー数字 (०१२३४५६७८९)だけで構成されていた場合にアラビア数字 (0-9)に置き換えるという処理を入れました。 これは、年(year)が解答になっている質問に対してアノテータがアラビア数字を使ってアノテーションするように一致していたからのようです。 後処理なしのスコアが0.806で、後処理によって0.02ほどの改善をしました。

後処理の実装の詳細について知りたい方は、以下のコードをご参照ください。 https://www.kaggle.com/tkm2261/tkm-tydi-rem-info-muril-xtream-xquad

Cross Validationについて

ローカルでのCross Validationの評価についてはチームで何もせず、Public Leaderboardのスコアに対してチューニングをしていきました。ローカルで評価をしなかった理由は以下になります。

  • Cross ValidationとPublic Leaderboardのスコアに全然相関がなかった
  • trainデータのアノテータは1名、testデータのアノテータは3名で行なったとDataページに説明があったので、trainデータの方にはノイズがありCross Validationの信頼性が低いと推測できた
    • データを目視確認すると確かにアノテーションのブレがあった。

Public Leaderboardのスコア変化

  • 0.787: チームマージ
  • 0.799: TyDi QA (ベンガル語とテルグ語)を用いて学習
  • 0.816: RembertとInfoXLMをアンサンブル
  • 0.821: より多くのXLM-R, Rembert, InfoXLMモデルをアンサンブルに追加
  • 0.827: アンサンブルの重みをチューニング
  • 0.829: MuRILモデルをアンサンブルに追加
  • 0.831: 後処理のパラメータチューニング

順位変動 (Shake) について

今回のchaiiコンペはPublicとPrivate Leaderboardの順位変動が大きかったコンペでした。どの程度順位変動があったかをプロットしてくださったNotebookがありましたので参考までに示します。

f:id:kambehmw:20211207171930p:plain
PublicとPrivate Leaderboardの順位変動
引用: https://www.kaggle.com/c/chaii-hindi-and-tamil-question-answering/discussion/287960

左上のPublic Leaderboard 550位くらいから、Private Leaderboardで金圏に入られていた方もいたので比較的順位変動が大きくあったコンペだったと言えるのではないでしょうか。 私たちのチームがShakeしなかった理由は以下であると個人的に考えています。

  • アンサンブルで異なるモデルを使うことで多様性を上げることができた
  • 外部データを活用することで、データ量を増やすことができた
  • testデータの方がラベルが正確だと信じることができ、Trust LB(Leaderboard)戦略が今回のコンペでは正解だった

精度に効かなかったこと

chaiiコンペを振り返って

今回のコンペは業務であまり経験のない質問応答というタスクでしたが、コンペを通じて多くの知見を得ることができました。コンペデータとは異なる言語データを含めて学習することで精度改善が見られたので、日本語についても言語的に近い他言語のデータを活用することで精度改善に役立つのではと思いました。また、RemBertやInfoXLMといった今まで試したことがなかったマルチリンガルモデルも精度的に役立つことがわかりました。コンペを通じて得られた知見を今後の業務に活かしていきたいと思います。

今回のコンペでKaggle Competitions Masterになることはできましたが、引き続き技術力を高めていくことでエクサウィザーズのビジネスにより一層貢献することを目指していきたいと思います。

エクサウィザーズでは一緒に働く人を募集しています。中途、新卒両方採用していますので、興味のある方は是非ご応募ください! hrmos.co event.exawizards.com

参考文献

Kaggle初参加振り返り〜Shopeeコンペでソロ銀メダル〜

こんにちは。MLエンジニアの川畑です。
今回は、以前から気になっていながらも、中々参加の一歩が踏み出せなかったKaggleについに参加したところ、幸いにも 46th / 2426チームで銀メダルを獲得しましたので、初参加を振り返りたいと思います。 なお、本記事では上位陣の解法の詳細は紹介しませんので、ご興味がある方はKaggleのコンペサイトに投稿されている解法を参照ください(https://www.kaggle.com/c/shopee-product-matching)。

f:id:k_kawabata:20210601120005p:plain

前提

  • 取り組んだ時間
    • 本コンペに参加し始めたのはコンペ終了まで残り1ヶ月を切った頃で、基本的には業務終了後から就寝までの間で平均2時間程度、週4〜5日程度Kaggleに時間を割いていました。自分は休日にダラダラするのが好きなので、どちらかというと平日の方が取り組んでいました。また、幸いなことにコンペ開催終盤にはGWが重なったので、そのタイミングで比較的時間を費やすことができました。
  • 実験環境
    • 私は自前のGPUを持っていないため、Kaggle notebookとGoogle Colabのみで実験を回しました。1週間に利用できる時間や連続で利用できる時間などに制限があるため、大量の実験や時間のかかる実験を回す必要があるコンペでは、これらの無料リソースだけだとかなりしんどいのではないかと思います。また、Google Colabでは毎回データをコピーしてこないといけなかったり、1日ガッツリ利用すると使用量上限に達し、次の日に制限がかかって利用できなかったりと、少々使いづらい部分もありました。とはいえ、無料でGPUを利用できるのは非常にありがたいことです(有料のGoogle Colab Proを使えばもう少し制限は緩和されます)。

コンペ概要

本コンペは、2021/03/08~2021/05/10にShopeeという東南アジアのeコマースプラットフォームが開催したもので、商品の画像とタイトルから同一商品を検索するという課題でした。

一般に小売企業は、自社の商品が最も安いことをお客様に保証するために、さまざまな方法を用いています。中でも、他の小売店で販売されている商品と同等の価格で商品を提供する「プロダクトマッチング」という手法があります。しかし、同じ商品であっても、小売業者によって使用する画像やタイトルなどが大きく異なることもあり、同じ画像や同じタイトルで単純にマッチングさせるだけでは不十分です。そこで機械学習を使ってマッチングの精度を向上させたい、というのが本コンペのモチベーションと考えられます。実際に、Shopeeでも掲載されている何千もの商品に対して「最低価格保証」機能を提供しており、このコンペでのアプローチが活用されることが想定されます。

提供データ

学習データには34250件の商品が含まれ、テストデータには70000件程度の商品があると記載されていました。

  • posting_id : 投稿ID(各投稿にユニークなID。商品が同じでも投稿者が違えば異なるIDになる)
  • image: 商品画像
  • image_phash: 商品画像のperceptual hash(これは画像を64bitのハッシュ値に変換したもので、同じ画像からは同じ値、また似た画像からは似た値が得られます。ちなみにモデルを作成する上で、この情報は使用しませんでした)
  • title: 商品タイトル(英語とインドネシア語が混ざっていた)
  • label_group: ラベルグループID(これが同じものが同一商品とみなされる)

f:id:k_kawabata:20210601120341p:plain
提供データの例

評価指標

  • F1-score
    • 予測対象は、対象商品と同一の全ての商品の投稿ID(posting_id)。ただし,同一商品には必ず自分自身を含み,上限は50個。
    • 各行(各商品)に対してF1-scoreを計算し、それを全体で平均したもの

以下はサブミッションファイルの例です。各クエリ商品に対して、その商品とマッチする商品の投稿ID全てをスペース区切りのリストで与えます。マッチする商品には必ず自分自身を含むため、matchesの列には、クエリ商品自身の投稿IDが含まれます。以下の例では、 test_123は自分以外にマッチする商品がなく、 test_456は自分以外に test_789とマッチすることを意味します。

posting_id,matches
test_123,test_123
test_456,test_456 test_789

自分の解法

f:id:k_kawabata:20210601120545p:plain
予測パイプライン図

上が私の解法のパイプライン図です。大きな流れとしては、画像とテキストそれぞれでモデルを作成し、それぞれから得られた埋め込み表現を元にkNNでマッチング商品を予測し、最後に和集合をとる、というものです。上記パイプライン図におけるTF-IDF以外の4つのモデルに対しては、損失関数としてMetric learning(距離学習)で使われるArcFace loss[1]を使用しました。距離学習について簡単に説明すると、埋め込み空間において、同じクラスのものはなるべく近づけ、異なるクラスのものはなるべく遠くになるように埋め込み表現を学習する手法です。ArcFaceは、本コンペと類似の過去コンペでも有用性が示されており、本コンペでも参加者の多くがこの手法を用いていたと思われます。したがって、ArcFace自体の利用は差別化ポイントではないため、本記事では詳細は省かせていただきます。

画像モデル

以下の3つをバックボーンとして使用し、損失関数にはArcFaceを用いることで3つの埋め込み表現を得ました。それぞれのモデルから得られる埋め込み表現の次元は512次元です。

  • RegNetY
  • EfficientNet-B3
  • NFNet-L0

様々なバックボーンを使った結果が公開ノートブックに挙げられていましたが、それらのスコア差は比較的小さく、バックボーンの選択自体は本コンペにおいて優先度は高くないと判断し、ほとんど試行錯誤はしていません。しかし、細かいスコアアップにはもちろん繋がるため、より上位を目指す上ではもう少し拘っても良かったかもしれません。余談ですが、上位陣の解法でVision Transformerの一種である Swin Transformer [2]を使っているケースもいくつか見られたので、画像コンペにおいてもすでにTransformerが威力を発揮していることが驚きでした。

テキストモデル

以下の2つをテキストのモデルとして使用しました。画像モデル同様にあまり事前学習モデルの選択には時間を掛けていません。

  • Paraphrase-XLM-multilingual [3]
  • TF-IDF

商品タイトルには英語とインドネシア語が混ざっていたため、多言語で事前学習された言語モデル( Paraphrase-XLM-multilingual )を使用しました。他の参加者の解法を見ると、インドネシア語で学習された IndoBERT [4]を使用しているケースも多かったようです。
また、一般的な文章と比較して、商品のタイトルは単なる単語の羅列に近いため、TF-IDFのような単語をカウントして重み付けするだけのシンプルな手法でも有効だったと考えられます。実際にdiscussionなどを見ていても、多くの参加者がTF-IDFの有用性に言及していました。

以下では、公開ノートブックやdiscussionでは触れられていなかったものの、スコアアップに効果があったものをいくつか紹介します。

GeM pooling[5]

過去に行われた画像検索コンペのGoogle Landmark Retrieval 2019における1位の解法[6]を参考に、pooling層にGeneralized-mean (GeM) pooling(パイプライン図の緑色部分)を使用しました。pooling層の出力は以下の式で得られます。

\displaystyle{\mathbf{f}^{(g)}=[f_1^{(g)}...f_k^{(g)}...f_K^{(g)}]^{\top}, \quad f_k^{(g)}=\left(\dfrac{1}{|\mathcal{X_k}|}\displaystyle\sum_{x\in{\mathcal{X_k}}}x^{p_k}\right)^\frac{1}{p_k}}

ここで、\mathcal{X}_{k}k番目の特徴量マップであり、k\in{{1,...,K}}です。パラメータp_kは学習も可能ですが、簡単のため、論文著者らがGithubに公開しているコード[7]のデフォルト値(p=3)を使用しました。なお、p_k\rightarrow\inftyの場合はmax pooling、p_k=1の場合はaverage poolingに対応します。ちなみに、GeMの効果は後述するDBAやGraph-based QEと比較すると非常に小さかったです。

Database Augmentation (DBA)[8]

こちらもGoogle Landmark Retrievalの上位解法[9]を参考にしました。パイプライン図では、黄色部分になります。DBAは非常にシンプルな手法で、各商品の特徴量に対して近傍k個の特徴量との重み付き和を計算し、それを元の特徴量と置き換えるものです。ここでの特徴量とは、画像とテキストのモデルから獲得した埋め込み表現をPCAによって512次元に圧縮した後の埋め込み表現となります。

\displaystyle{\mathbf{x}_{new}=\sum_{i=1}^{k}w_i \mathbf{x}_i}

ここで、\mathbf{x}_iは元の特徴量からi番目に近い特徴量で、i=1は常に自分自身の特徴量です。パラメータkの値を大きくしすぎると、クエリと距離が遠い別の商品の特徴量も含めてしまうため、調節が必要となります。kの値はいくつかのパターンで試しましたが、k=2の場合が最も良い結果となりました。以下の図は横軸に自分を含めたマッチング商品数をプロットしたヒストグラムですが、この図から分かるように、マッチング商品数が2個の商品が最も多いことがk=2で最も良い結果になった理由と考えられます。つまり、同一商品が2個の商品に対してk\geqq3としてしまうと、異なる商品まで含めてしまうため、特徴量に悪影響を及ぼします。

f:id:k_kawabata:20210601122222p:plain
マッチング商品数に関するヒストグラム

重みに関しては、[9]を参考にして、w_1=1.0, w_2=0.67としました。ただ、この重みの決め方はクエリと最近傍商品の類似度を考慮できておらず、似た商品であろうと似ていない商品であろうと、近傍に対して固定の重みを掛けます。そこで、重みを固定値にするのではなく、類似度(の冪乗)を使うことでよりDBAの質が上がると考えられます。実装自体は簡単なのですが、なぜか私はここで面倒臭がってしまい、結局類似度を使った重み付けは行いませんでした。しかし、コンペ終了後に上位の解法で使われているのを見ると、やはりこれを行っていた方が良かったと後悔しました。

Graph-based QE

先に紹介したDBAと同様に、過去の類似コンペの上位解法でよく使われている手法にQuery Expansion (QE)[10]があります。これは、あるクエリに対して何らかの方法で新しい別のクエリを作成し、元のクエリと合わせて2つのクエリを使って検索をする手法です。この利点として、効果的な新しいクエリを追加できれば、元のクエリだけでは検索でヒットしなかった商品もヒットさせることできるようになります。では、どのように新しいクエリを作成するかというと、よく使われる方法としては、クエリの近傍k個の特徴量との平均を取るものや、類似度(のα乗)を使った重み付き平均などがあります。後者は、α-QE [5]と呼ばれ、\alpha=0の時普通の平均と一致します。

私は実装を簡単にするのと、計算時間を短縮するために、よりシンプルな手法を用いました(パイプライン図のオレンジ部分)。具体的には、新しいクエリを作成する代わりに、すでに同一商品と予測されている商品群を新しいクエリとみなしました。例えば、クエリQに対し予測によって以下のようなグラフが描けたとします。

f:id:k_kawabata:20210601122446p:plain
クエリQに対する予測のグラフ例

ここで、Qからエッジが引かれている商品(A, B, C)はQと同一商品と予測されたものです。この時、QDの間にエッジはありませんので、もしDQと同一商品だった場合は、Dを見逃してしまいます。ですが、すでにQと同一と予測されている商品Cを新しいクエリとみなすと、Dも予測結果に追加することが可能となります。クエリからの近傍k個をナイーブに新しいクエリに追加するのではなく、先に類似度に対してある閾値でスクリーニングを掛け、残ったもの(同一商品と予測されたもの)のみを新しいクエリとして追加することで、False positiveをなるべく上げないようにしているのがポイントです。

また、今回の解法の中でTF-IDFを使用していますが、TF-IDFのように単語の意味を考慮できない単純な特徴量では、ちょっとした特徴量の変化でも意味的には大きく変化することもあるため、近傍の特徴量を利用することはFalse positiveを増やし、逆効果となることが考えられます。これは、DBAも同様です。したがって、TF-IDFに対しては、DBAやGraph-based QEは使用していません。

上位解法との差分

以下では、自分の解法には含まれていないものの上位解法には含まれており、さらなるスコア向上のためには重要だったと考えられるテクニックについていくつかご紹介します。

2nd stageモデル

2位、3位、5位、10位のチームが2nd stageモデルを作成していました。これは、画像やテキストの埋め込み表現から類似度を計算し、ある閾値で同一商品を予測するのではなく、類似度や距離を元にもう一段階モデルを組んで最終的な予測とするものです。具体的には、同一商品の候補となる商品ペアに対して、それらの類似度や距離などを特徴量として、LightGBMやXGBoostなどで対象ペアが同一商品かどうかを予測します。テストデータには約70000件の投稿商品が存在し、ペアの組み合わせ数が膨大であるため、いかにこの処理を高速化するかが重要だったようです。高速化には、cumlのForestInference [11]というGPUを使って推論を行うライブラリが有用で、数十倍の高速化ができるようです。

画像とテキストの予測の組み合わせ

今回のコンペでは、画像での予測結果とテキストでの予測結果をいかに上手に組み合わせるか、という点も重要だったように思います。画像が似ている商品(下図の黄色領域)とテキストが似ている商品(緑色領域)だけではなく、画像とテキストがどちらもそこそこ似ている商品(青色領域)も予測に加えることでスコアが向上したようです。下の図の例では、クエリQに対して、自分自身を含む[Q, A, D, E, F, G]を予測結果とします。

f:id:k_kawabata:20210603143128p:plain
予測の組み合わせ(1位解法[12]から抜粋)

類似度を用いたDBAやQE

DBAやGraph-based QEの項でも触れましたが、クエリ近傍の特徴量に対して何らかの方法で重み付けして新しい特徴量を得る際には、固定の重みではなく類似度を利用する方がやはり精度は良くなるようです。DBAやQEを使っていたチームはほぼ全て類似度を使った重み付けをしていたのではないかと思われます。このやり方自体には気付いていたので、面倒臭がらずに実装していれば良かったと後悔しています。

Kaggle初参加を振り返って

今回、Kaggleに初めて参加しましたが、2、3年ほど前からKaggleの存在自体は知っていました。ただ、業務が忙しくKaggleに割く時間がない、と自分の中で勝手に言い訳をして、長いこと参加してきませんでした。しかし、業務で用いるアプローチが、すでに自分が知っているものや過去に使ったことのあるものばかりであることに気付き、もっとアプローチの幅を広げたいという思いから、最新技術に触れることができるKaggleに参加することにしました。私は普段、テーブルデータを扱うことが多いため、画像やNLPの技術も使えるようになりたい、という思いもありました。今回Shopeeコンペを選んだのもそのためです。

序盤は、初参加でしかもソロ参加だったため、勝手が分からず、「Submission CSV Not Found」を連発し、苦労しました。ただ、参加したのがコンペ終了まで残り1ヶ月を切った頃で、すでに公開コードやディスカッションが充実していたため、取り掛かりやすく、タイミング的には良かったと思っています。一方で、他の参加者(特に上位陣)と比較すると、自分は実験の数が圧倒的に足りなかったなと思いました。より高いスコアを目指す為には、公開されているノートブックやdiscussionなどの情報を単に鵜呑みにするのではなく、問題の本質を見抜く為に自分の頭でしっかり考え、データを確認し、その上で多くの実験を繰り返すことが重要だと感じました。

実際に参加してみて思ったのは、Kaggleは世界中のデータサイエンティストが自分たちの知識や技術を惜しげもなく共有し、機械学習の知見を深めるには非常に効果的な場であるということです。今はまだ、恩恵にあずかるだけですが、いつか自分もこのコミュニティに知見を還元できるようになりたいと思いました。次回参加する際には、ぜひ金メダルを獲得したいです。

参考リンク

[1] https://arxiv.org/pdf/1801.07698.pdf
[2] https://arxiv.org/pdf/2103.14030.pdf
[3] https://arxiv.org/pdf/2004.09813.pdf
[4] https://arxiv.org/pdf/2011.00677.pdf
[5] https://arxiv.org/pdf/1711.02512.pdf
[6] https://www.kaggle.com/c/landmark-retrieval-2019/discussion/94735
[7] https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/layers/pooling.py
[8] https://arxiv.org/pdf/1610.07940.pdf
[9] https://www.kaggle.com/c/landmark-retrieval-challenge/discussion/57855
[10] https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf
[11] https://medium.com/rapids-ai/rapids-forest-inference-library-prediction-at-100-million-rows-per-second-19558890bc35
[12] https://www.kaggle.com/c/shopee-product-matching/discussion/238136