エクサウィザーズ Engineer Blog

株式会社エクサウィザーズのエンジニアチームブログ

AI技術を社会実装して課題解決に挑むチームの「技術を理解する&伝える」お話

はじめに

  • こんにちは。AIインキュベーション室で室長をしています長谷川大貴と申します。 エクサウィザーズに入ってそろそろ5年が経過しようとしておりまして、ずっと事業開発やAIを活用したプロジェクトの立ち上げと推進等を数多く実施しておりました。 元々理系出身ではありますが、非エンジニア職です。
  • これまでどんなことしてたの?と言うことをちょろっと知れる記事はこちらにありますので、もしご興味とお時間がある方があればご参照ください。
  • 記事①:「実はここがエクサウィザーズの起源!」AIP西日本事業部ってどんなことやってるの?https://www.wantedly.com/companies/exawizards/post_articles/383676#=
  • 記事②:垣根を超え、想いを形に--。関西からAIの力で情報・教育格差を無くしたい。https://note.exawizards.com/n/n602eb1dcb253

  • そんな非エンジニアがなぜこのブログに!?と思われるところかと思いますが、ずっと事業開発やビジネス側で活動してきたのですが、2022年10月から技術統括側のAIインキュベーション室と言う組織を立ち上げさせていただきエンジニアリングチームとして活動しています。そこで、「技術とビジネスを結び付け、新しい価値を創造する」と言う事に取り組んでおります。

  • 例えば、エンジニアが「新しい技術シーズを開発したけどこれって課題解決に活用できないかな?」と言う技術を理解してニーズに繋いだり、企業等のニーズに対してそう言った技術シーズの蓄積から解決につながる提案を実施したりしています。
  • 私自身はAIの研究者でもゴリゴリコーディングをしているわけでもないのですが、何件もの技術を理解し、活用のアイデアを検討しています。今回は技術を理解する時に私が心がけていることを少しシェアさせていただければなと思っております。 ビジネス側の方は「そういう感じで技術理解しているのかー」と思っていただいたり、エンジニア側の方は「こう言う観点で技術を伝えたら理解されやすいのかな―」と言うヒントに使っていただけると嬉しいです。

「こんな技術があるんですけど!」と言われたら(言いたくなったら)

  • 技術の紹介や相談を受けるとき個人的に大事に思ってるのは双方のリスペクトと知的好奇心の強さです。リスペクトは仕事をするうえでは当たり前なのですが、ビジネスパーソンとエンジニアだったり、経営者とエンジニアだったりは一般的には思考方法が異なるので、双方向の尊重は前提にあるべきだと考えています。
  • 知的好奇心については、私自身は元々強い方だと思っているのですが、「え?面白いじゃないですか!どんな技術なの!?どんどん質問しちゃっていいですか!?」と言うスタンスで聞けることが大事かなと思ってます。その方が、技術の理解や深堀のスピードが上がりますし、活用のアイデアも技術をヒアリングしていく中で出やすい雰囲気になると思います。技術を紹介する時に何となく相手が引いているなと思ったら、相手の知的好奇心をくすぐる情報をアイスブレイク的に放り込むのもありだと思います。
  • 私がお客さんに技術を伝える時、知的好奇心をくすぐるために「ついつい誰かに話したくなる一ネタを入れる」と言う事をよくやっています。 例えば、最近話題のChatGPTは非常に高い精度でテキスト生成をすると言うモデルですが、文章を生成すると言う事は1文字誤りやノイズがあれば、意味が変わってくるケースがあります。実際にあった答えなのですが、「免疫力をつけるにはどうすれば良いですか?」と言う問いの答えの一つに「週あたり7時間以上の睡眠を心がけましょう」と言うのがありました。恐らく正しくは「日あたり7時間以上の睡眠を心がけましょう」だと思うのですが、1文字違うと途端にハードコアな働き方を要求する文章に変わってしまいます。 生成にはこう言う誤った情報を生成してしまうリスクが少なからずあるので、もし可能な限り情報ソースに忠実な文章要約をAIで実現したい場合は、上記のような生成系の技術ではなく抽出型要約(元の文章を抽出して要約を作る方法)を選択することもできますよと言う少しクスッとするエピソードとともに紹介すると「確かに」と技術選択の腹落ち感が増したり、「この要約AIのポイントは元の文章は維持して意味が伝わる内容に要約することなんですよ。生成系技術を使うとこんなリスクもあるんですよ~」とお客さんも社内に技術の特徴を紹介しやすくなりますので、そう言う一ネタを入れることをよくやってます。

「どんな技術か」を理解する(伝える)

  • 前述のようなスタンスで技術をヒアリングしたり議論したりする中で私がおおよそいつも聞くようにしている要素を参考にご紹介します。技術の種類によっては他の色々なことももちろん聞きますが、下記の要素はどう言ったものであっても知っておくべきポイントだと思って聞いています。
  • それぞれを網羅的に聞いていくというよりは、できるだけ自分も知的好奇心を満たすような流れで自然に埋めていけるようにしようとするコミュニケーションの工夫もよくしています。「その技術って○○○ができるんですよね?それって×××みたいな用途でも使えたりしますか??・・・でももしかしたらこういう時は使えないですか??」等々

1. 【概要】どんな技術?何ができる?

  • その技術ができることやどのような技術であるかを自分の言葉で説明できる程度に理解したい。伝聞で自分が説明した時でもユーザーに技術の良さと面白さを伝えられるようにしたいと思い聞いています。

2. 【背景】なぜこの技術に注目した?開発しようと思った背景は?

  • エンジニアがときめいたポイントや良いと思った背景があれば聞きたい。もしくはその他の理由「得意な技術だから、過去にしっかり実装した経験があるから等」があるのであれば、バックグラウンドを理解して、技術の魅力を伝えられるようになりたいと思い聞いています。

3. 【原理】技術や手法の仕組みやポイントは?

  • 技術の原理や仕組み、実装の工夫やポイント等を理解し、技術ができることをある程度の深さの原理のレベルで理解したい。そうすることでユースケースを考えるときの実現性やユーザーの納得感「確かにそういう技術であれば実現できそうだ」を付加するために聞いています。

4. 【強み】既存技術と比べてすごい点はある?特徴はある?

  • 技術の強みや特徴、他との差別化ポイントを把握して、「これまでできなかったことができるようになった」だったり、「これまで実現しようと思ったら時間や費用が掛かっていたものが少なくて済むようになった」等の技術の強みによってユースケースの付加価値が高いポイントを作るために聞いています。

5. 【限界】利用の前提や制約、注意点はある?

  • できないことを明確にし、前提や条件がある時にはそれを含めてユーザーに提案し「嘘」や「誇張」、「過剰な期待値」を生まない誠実な提案を作成するために注意して聞くようにしています。

6. 【応用】技術の応用範囲は?こんなことには使える?あんなことには使える?

  • こう言う用途にも使えるか?工夫したらこう言うこともできるか?等々技術を応用可能な幅を具体的なユースケースの案を仮説として提示しながらできる範囲をイメージしていきます。ここでどの程度の幅でアイデアが出るか、具体的にユーザーが欲しいと思うものの仮説を出せるかがビジネスパーソンが普段接している顧客のニーズの質や量に差が出て面白い部分だと思います。

  • 数が出て議論が盛り上がるケースもあれば、とても具体的かつ根が深いニーズにマッチして盛り上がるケースもあるので、ある程度色々なメンバーで議論した方が面白い応用案が出て良いと個人的には感じています。例えば、エクサウィザーズと言う会社は面白い会社で多様なバックグラウンドの方が多くいらっしゃるのですが、ビジネスとエンジニアで議論していて煮詰まった時に、ケア事業を実施している介護士の方から「その技術であればこういうことにもしかして使えないか?」と言う発言をいただいて思いもよらなかったユースケースが誕生することもあったりします。

「何に使えるか」を想像し、実際にユーザーに提案をしに行く

  • 上記までで技術の概要が理解できれば、技術の強みや限界を加味したうえで、今まで蓄積しているニーズ等を思い返しながら「こう言ったユースケースだと欲しい人がいるのでは」と言う仮説を作ります。その仮説を作ったら実際ターゲットに思い描いていたユーザーに実際に提案しに行ったりディスカッションによってニーズを確認しに行きます。 そこで一発で「欲しい!」となることは稀で、反応を踏まえて提案のチューニングをしたり仮説の修正をしたりすることが多いです。結構この辺りが大変なのですが、こちらはまた機会があれば・・・

まとめ

  • 技術をビジネスに転換していく際にはエンジニア、ビジネスパーソン、デザイナー等々のそれぞれのメンバーが技術と相手をリスペクトし、知的好奇心を持ってディスカッションを実施した方がポジティブな議論ができる
  • 技術を伝える&聞くときにはある程度普遍的に共有した方が良い内容は事前に整理したり意識しながら共有した方が理解は加速する
  • もし技術&ビジネス間の連携のお話やエクサウィザーズと言う会社に少しでもご興味を持っていただけましたら、下記サイトをご覧くださいhttps://hrmos.co/pages/exawizards/jobs?

Apolloを利用したGraphQLリクエストのパフォーマンスをFirebase Performance Monitoring で測定する

こんにちは。エクサウィザーズの介護記録AIアプリ「CareWiz ハナスト」(以下ハナスト)でiOSアプリ開発を担当している伊賀(@iganin_dev)です。

ハナストのテックリードの原のブログ記事にもありましたように、ハナストではAPI通信にGraphQLを利用しています。 本稿ではiOSアプリの通信ライブラリとしてApolloを用いた場合のGraphQLリクエストのパフォーマンスをFirebase Performance Monitoring(以下FPM)を使用して測定する方法に関して記載します。

環境

本稿記載の内容は以下環境を前提に記載しています。

  • Xcode 14.0
  • apollo-ios 0.51.0 (※ 1.0.0へのバージョンアップ検証中)
  • firebase-ios-sdk 9.6.0

ハナストについて

本題に入る前に「CareWiz ハナスト」に関して簡単にご紹介します。 ハナストは簡単に言うと「音声入力で介護の記録をするアプリ」です。

以下のLPによくまとまっています。 利用シーンを紹介するデモビデオもありますので、是非ご覧ください。

hanasuto.carewiz.ai

Firebase Performance Monitoringとは

FPMはGoogleが提供しているFirebaseの機能群の一つです。 ネットワークリクエストをはじめ、さまざまな処理にかかった時間や処理の結果(成功・エラー)などを記録、集計しGUIを通してグラフィカルに確認することができます。

ライブラリを追加するのみでアプリの起動時間や画面の滞在時間などを測定してくれる非常に便利なツールです。 ネットワークリクエストも自動的に計測し、カスタムURLパターンを作成すれば特定のリクエストの計測もできます。

FPMでGraphQLリクエストのパフォーマンスを測定する場合の問題点

非常に便利なFPMですが、GraphQLのリクエストを測定しようとした場合に問題が発生します。GraphQLのリクエストは一般的には同一エンドポイントへのPOSTリクエストとなります。例えば、 https://sample.com/graphql のようなPOSTリクエストのBodyにQueryやMutationのGraphQLドキュメントをのせリクエストを送ります。

FPMではリクエストのパフォーマンスをURLのパスを元に分類します。従って、GraphQLリクエストのパフォーマンスを測定しようとした場合、そのままではすべての計測結果がsample.com/graphqlのような単一のエンドポイントに集約されてしまい、各リクエストのパフォーマンスを別々に見ることができません。それではどのようにすれば、GraphQLリクエストのパフォーマンスをFPMで測定できるのでしょうか。

カスタムネットワークリクエストトレースについて

FPMでは自動収集するリクエストトレース以外に、開発者にて実装できるカスタムネットワークリクエストトレースを用意しています。HTTPMetricをurlとhttpMethodを引数で渡して初期化し、start()stop()を呼び出すことで、start()からstop()を呼び出すまでの時間を計測することができます。FPMのGUI上ではここで指定したurlがパスの分類に使用されます。また、HTTPMetricにはリクエストやレスポンスのペイロードサイズ、レスポンスのステータスコードを登録することもできます。

guard let metric = HTTPMetric(url: url, httpMethod: .post) else { return }
metric.start()
metric.requestPayloadSize = requestPayloadSize

Task {
    do {
        // ネットワークリクエスト実行
        let (data, response) = try await URLSession.shared.data(from: url)
        metric.responsePayloadSize = responsePayloadSize
        metric.responseCode = response.httpResponse.statusCode
        metric.stop()
    } catch let error {
        // エラーハンドリング
        metric.stop()
    }
}

ハナストでの解決方法

先ほどの全てのGraphQLリクエストの計測結果がまとめて集約されてしまうという問題に対して、ハナストではカスタムネットワークリクエストトレースの仕組みを活用して対処しています。基本的な解決方法としてはリクエストごとにURLのパスを分けるというものです。

Apolloを用いて自動生成されたQueryやMutationのclassにおいて、そのOperationの名称をoperationNameで取得することができます。例えば、以下のようなQueryをベースにSampleQuery.graphql.swiftのようなファイルが生成されます。

query sample {
    user {
        id
        name
    }
}
public final class SampleQuery: GraphQLQuery {
    ...
    public let operationName: String = "sample"
    ...
}

このoperationNameは作成したGraphQLファイルのqueryやmutationの名称と1対1で対応するため、operationNameをもとにリクエストを一意に特定することができます。そのため、さきほどのHTTPMetric初期化時に渡すurlにこのoperationNameを付与することでリクエストのPathがGraphQLのリクエストごとに異なるようになり、FPMのGUI上で各リクエストのパフォーマンスを測定することができるようになります。

実装

先ほどoperationNameをURLに付与してリクエストを一意に特定することで問題に対処するとお伝えしました。次にApolloを用いた場合の実装例をご紹介します。具体的な実装に入る前に、Apolloを用いた実装を確認しておきましょう。 まず、Apolloの初期化は下記のように行われます。

let cache = InMemoryNormalizedCache()
let store = ApolloStore(cache: cache)
let client = URLSessionClient()
let transport = RequestChainNetworkTransport(
        interceptorProvider: SampleInterceptorProvider(
            store: store,
            client: client,
            apiConfig: apiConfig
        ),
        endpointURL: endpointURL
    )
ApolloClient(networkTransport: transport, store: store)

Apolloからリクエストを送ると、interceptorProviderfunc interceptors(for _: some GraphQLOperation) -> [ApolloInterceptor]メソッドが提供するApolloInterceptorが通信リクエストの内容とレスポンスにさまざまな処理を加えたり、それらをもとにさまざまな処理を行い、最終的な返却値を返します。なお挙動を確認したところ、func interceptors(for _: some GraphQLOperation) -> [ApolloInterceptor]はApolloからリクエストを送る都度呼び出されていました。

final class SampleInterceptorProvider {
    func interceptors(for _: some GraphQLOperation) -> [ApolloInterceptor] {
        var interceptors: [ApolloInterceptor] = []
        // pre fetch interceptors - fetch前に行いたい処理のinterceptorをここに実装します
        interceptors.append(NetworkFetchInterceptor(client: client))
        // post fetch interceptors - fetch後に行いたい処理のinterceptorをここに実装します
    }
}

今回の実装では、FPMの計測開始を担当するStartFPMMetricInterceptorとFPMの計測終了を担当するStopFPMMetricInterceptorを実装し、interceptorsNetworkFetchInterceptorの前後に追加することでリクエストの都度カスタムネットワークリクエストトレースが行われるようにしています。 StartFPMMetricInterceptorの実装は下記のようになっています。

final class StartFPMMetricInterceptor: ApolloInterceptor {
    init(performanceMonitor: any NetworkPerformanceMonitorable) {
        self.performanceMonitor = performanceMonitor
    }

    private let performanceMonitor: any NetworkPerformanceMonitorable

    func interceptAsync<Operation>(
        chain: RequestChain,
        request: HTTPRequest<Operation>,
        response: HTTPResponse<Operation>?,
        completion: @escaping (Result<GraphQLResult<Operation.Data>, Error>) -> Void
    ) where Operation: GraphQLOperation {
        let operationName = request.operation.operationName
        let url = request.graphQLEndpoint.appendingPathComponent("/\(operationName)")
        let requestPayloadSize = try? request.toURLRequest().urlRequest?.httpBody?.count
        performanceMonitor.start(url: url, method: .post, requestPayloadSize: requestPayloadSize)
        chain.proceedAsync(request: request, response: response, completion: completion)
    }
}

ApolloInterceptorに準拠した場合、interceptAsync<Operation>メソッドを実装する必要があります。 このメソッド内で、requestからoperationNameを取得し、urlのパスにoperationNameを追加しています。 StopFPMMetricInterceptorの実装は下記のようになっています。

final class StopFPMMetricInterceptor: ApolloInterceptor {
    init(performanceMonitor: any NetworkPerformanceMonitorable) {
        self.performanceMonitor = performanceMonitor
    }

    private let performanceMonitor: NetworkPerformanceMonitorable

    func interceptAsync<Operation>(
        chain: RequestChain,
        request: HTTPRequest<Operation>,
        response: HTTPResponse<Operation>?,
        completion: @escaping (Result<GraphQLResult<Operation.Data>, Error>) -> Void
    ) where Operation: GraphQLOperation {
        let statusCode = response?.httpResponse.statusCode ?? 0
        let responsePayloadSize = response?.rawData.count
        performanceMonitor.stop(statusCode: statusCode, responsePayloadSize: responsePayloadSize)
        chain.proceedAsync(request: request, response: response, completion: completion)
    }
}

ネットワークのレスポンスからstatusCodeやレスポンスペイロードサイズを取得し、HTTPMetricに渡しています。 NetworkPerformanceMonitorableはFPMのHTTPMetricの生成や保持、startstopの呼び出しを責務に持つプロトコルです。 NetworkPerformanceMonitorableと、それに準拠したFirebaseNetworkPerformanceMonitorの実装は下記のようになっています。

計測の開始時点で呼び出すstart(url:,method:)でurlとリクエストのペイロードサイズを渡し、そのタイミングでHTTPMetricを生成して保持しています。その後、計測の終了時点で呼び出すstop(statusCode:, responsePayloadSize:)にてレスポンスのペイロードサイズとstatusCodeをわたし、HTTPMetricstopを呼び出しています。

public protocol NetworkPerformanceMonitorable {
    func start(url: URL, method: PerformanceMonitorHttpMethod, requestPayloadSize: Int?)
    func stop(statusCode: Int, responsePayloadSize: Int?)
    func cancel()
}

public final class FirebaseNetworkPerformanceMonitor: NetworkPerformanceMonitorable {
    public init() {}
    private var metric: HTTPMetric?
    
    public func start(url: URL, method: PerformanceMonitorHttpMethod, requestPayloadSize: Int?) {
        let metric = HTTPMetric(url: url, httpMethod: method.convertToFirebaseHttpMethod())
        metric?.requestPayloadSize = requestPayloadSize ?? 0
        self.metric = metric
        metric?.start()
    }
    
    public func stop(statusCode: Int, responsePayloadSize: Int?) {
        guard let metric else { return }
        metric.responsePayloadSize = responsePayloadSize ?? 0
        metric.responseCode = statusCode >= 0 ?  statusCode : nil
        metric.stop()
    }
    
    public func cancel() {
        metric = nil
    }
}

最後にInterceptorProviderinterceptorsメソッドの返却値に上述のクラス群を追加すれば完成です。

final class SampleInterceptorProvider {
    func interceptors(for _: some GraphQLOperation) -> [ApolloInterceptor] {
        let performanceMonitor = FirebaseNetworkPerformanceMonitor()
        var interceptors: [ApolloInterceptor] = []
        // pre fetch interceptors - fetch前に行いたい処理のinterceptorをここに実装します
        interceptors.append(StartFPMMetricInterceptor(performanceMonitor: networkPerformanceMonitor))

        interceptors.append(NetworkFetchInterceptor(client: client))

        // post fetch interceptors - fetch後に行いたい処理のinterceptorをここに実装します
        nterceptors.append(StopFPMMetricInterceptor(performanceMonitor: networkPerformanceMonitor))
    }
}

これでApolloを用いたGraphQLリクエストの都度StartFPMMetricInterceptorStopFPMMetricInterceptorで計測のstartとstopが呼ばれるようになり、リクエストにかかった時間を計測できるようになります。 GUI上では例えば下記のように表示されます。

補足

今回の実装はGraphQLリクエストのCachePolicyとして.fetchIgnoringCacheCompletelyを指定した場合を想定しています。 他のCachePolicyを使用した場合はさらに考慮が必要になるだろうと思います。

まとめ

今回はハナストでGraphQLリクエストのパフォーマンスをどのように測定しているのかをご紹介しました。

本稿ではご紹介できませんでしたが、エッジでの音声認識、ウェイクワード検知、フルSwiftUIでのアプリ開発、Swift Concurrencyの実践的導入などハナストiOSアプリの開発はチャレンジングで面白い課題に日々挑戦しています。

CareWiz事業部およびエクサウィザーズでは社会課題の解決に一緒に取り組む仲間を募集しています。 介護をより良くするプロダクトの開発、あるいはAIで社会課題を解決するエクサウィザーズに少しでも興味がありましたら、是非ご応募ください!

hrmos.co

hrmos.co

open.talentio.com

PreloadResolverという仕組みを作ってGraphQLのN+1問題に対応した話

エクサウィザーズでハナスト開発チームのTLをしている原です。

ハナストは「音声入力で介護の記録をするアプリ」で、こちらのページでプロダクトの紹介をしています。

hanasuto.carewiz.ai

以前は、ハナストAPIのテストについてこちらの記事で書きました。

techblog.exawizards.com

今回の記事ではハナストのAPIで実践している、PrelaodResolverというGraphQLのN+1問題対応の仕組みを紹介します。

GraphQLのN+1問題

まず、GraphQLのN+1問題がどのように発生するかを簡単に説明します。

例えば

type Query {
  books: [Book!]!
}

type Book {
  id: ID!
  author: Author!
}

type Author {
  id: ID!
  name: String!
}

というスキーマがあったとします。

ここで、BookResolverが

class BookResolver {
  async author(book: Book): Promise<Author> {
    return findAuthorById(book.authorId)
  }

  ...
}

のような実装になっていたとすると

books {
  author { name { name }  }
}

というクエリでデータを取得した時に、Bookの数だけfindAuthorByIdが実行されてしまいます。

できればAuthorは

findAllAuthorsByIds(books.map((_) => _.authorId))

のような処理でまとめて取得するべきなのですが、そうではなく取得したBookの数だけfind処理が実行されてしまうのがN+1問題です。

DataLoaderを使う場合

GraphQL のN+1問題の対処にはDataLoaderを使うのが一般的です。

DataLoaderを使うと上記の例だと

const authorLoader = async (books: Book[]): Promise<Author[]> => {
  const authorIds = books.map(books.map((_) => _.authorId))
  const authors = await findAllAuthorsByIds(books)
  const authorMap = arrayToMap(authors)
  return books.map((book) => authorMap[book.authorId])
}

というDataLoaderを用意して、こちらをBookResolverに関連付けて、Book#authorが必要な場合はこのDataLoaderを呼んで、その結果からAuthorを取得するようになります。

実際の組み込み方はDataLoaderのライブラリによりますが、大まかな処理はこのようになります。

ハナストGraphQL APIでのN+1問題対策

DataLoaderでも一通りのN+1問題対策はできるのですが、ハナストではPreload Resolverという仕組みを作って、こちらでN+1問題の対策をしています。

PreloadResolverでのN+1問題対策

PreloadResolverを使った仮想コードはこのようになります。

class BookPreloadResolver {
  async preload(
    books: Book[],
    path: string[],
    info: GraphQLResolveInfo
  ): { authors: Author[] } {
    let authors: Author[] = []

    // GraphQLスキーマを分析して、Book#authorの取得が必要か判定する
    if(findFieldInSchema(path, 'author', info)) {
      authors = await this.preloadAuthor(books)
    }

    return { authors }
  }

  private async preloadAuthor(books: Book[]): Promise<Author[]> {
    const authorIds = books.map(books.map((_) => _.authorId))
    authors = await findAllAuthorsByIds(authorIds)
    const authorMap = arrayToMap(authors)

    books.forEach((book) => {
      const author = authorMap[book.authorId]
      book.preloadAuthor(author)
    })
  }

  ...
}

class Book {
  private preloadedAuthor: Author | undefined = undefined

  preloadAuthor(author: Author): void {
    this.preloadedAuthor = author
  }

  loadAuthor(): Author {
    if(this.preloadedAuthor) {
      return this.preloadedAuthor
    } else {
      // preloadされていない場合は例外を投げる
      throw new Error('author is not preloaded.')
    }
  }
}

class BookResolver {
  author(book: Book): Author {
    // preload済みのAuthorオブジェクトを取得する
    return book.loadAuthor()
  }
}

class QueryResolver {
  constructor(
    private readonly bookPreloadResolver: BookPreloadResolver,
  ) {
  }

  async books(info: GraphQLResolveInfo): Promise<Book[]> {
    const books = await findAllBooks()
    await bookPreloadResolver.preload(books, ['books'], info)
    return books
  }
}

PreloadResolverではDataLoaderとは異なり、BookオブジェクトにAuthorをpreloadしています。

参照側はpreload済みのAuthorオブジェクトを返して、preloadされていない場合は例外を投げるようにしています。

このような仕組みにすることで

  • ネストしたN+1問題への対処
  • GraphQL field毎のpreload可否の設定

がやりやすくなります。

以下で、それぞれについて詳しく説明します。

PreloadResolverによるネストしたN+1問題への対処

例えば

type Author {
  location: Location!
}
type Location {
  address: String!
}

のようにBook#authorsからさらにネストしてAuthor#locationを取得する必要がある場合は、以下のようにPreloadResolverを呼び出します。

class QueryResolver {
  constructor(
    private readonly bookPreloadResolver: BookPreloadResolver,
    private readonly authorPreloadResolver: AuthorPreloadResolver,
  ) {
  }

  async books(info: GraphQLResolveInfo): Promise<Book[]> {
    const books = await findAllBooks()
    const { authors } = await bookPreloadResolver.preload(books, ['books'], info)
    await authorPreloadResolver.preload(authors, ['books', 'author'], info)
    return books
  }
}

class AuthorPreloadResolver {
  // locationのpreload処理を実装
}

class AuthorResolver {
  // preload済みのlocation取得処理を実装
}

こうすることでBook#authorだけでなく、Author#Locationもpreloadされます。

GraphQL field毎のpreload可否の設定

例えば

type Query {
  books: [Book!]!
  authors: [Author!]!
}
type Book {
  author: Author!
  category: Category!
}
type Author {
  books: [Book!]
}
type Category {
  id: ID!
  name: String!
}

のようなGraphQLスキーマとなっていて、

  • books { author { name } }
  • authors { books { category { name } } }

はOKだけど、

  • authors { books { author { name } } }

のような循環参照のpreloadは禁止したいというケースがあります。

このような場合、まずPreloadResolverをこのように実装します。

type BookPreloadFields = {
  author?: AuthorPreloadFields
  category?: CategoryPreloadFields
}

type AuthorPreloadFields = {
  books?: BookPreloadFields
}

type CategoryPreloadFields = {}

class BookPreloadResolver {
  async books(
    path: string[],
    info: GraphQLResolveInfo,
    options: { fields: BookPreloadFields }
  ): Promise<{ authors: Author[], categories: Category[] }> {
    let authors: Author[] = []
    let categories: Category[] = []

    // GraphQLスキーマを分析して、authorの取得が必要か判定する
    if(findFieldInSchema(path, 'author', info)) {

      // authorの取得が禁止されていたら例外を投げる
      if(options.fields.author == undefined) {
        throw new Error(`Book#author preload is forbidden at ${path.join('.')}`)
      }

      // 許可されていたらauthorをpreloadする
      authors = await this.preloadAuthor(books)
    }

    // GraphQLスキーマを分析してcategoryの取得が必要か判定する
    if(findFieldInSchema(path, 'category', info)) {

      // categoryの取得が禁止されていたら例外を投げる
      if(options.fields.category == undefined) {
        throw new Error(`Book#category preload is forbidden at ${path.join('.')}`)
      }

      // 許可されていたらcategoryをpreloadする
      categories = await this.preloadCategory(books)
    }

    return { authors, categories }
  }
}

このQueryResolver側では以下のようにPreloadResolverを使います。

class QueryResolver {
  ...

  async books(info: GraphQLResolveInfo): Promise<Book[]> {
    const books = await findAllBooks()

    await this.bookPreloadResolver.preload(
      books,
      ['books'],
      info,
      // books { author { * } category { * } } を許可する
      { fields: { author: {}, category: {} } }
    )
    return books
  }

  async authors(info: GraphQLResolveInfo): Promise<Author[]> {
    const authors = await findAllAuthors()
    const { books } = await this.authorPreloadResolver.preload(authors, ['authors'], info)
    await this.bookPreloadResolver.preload(
      books,
      ['authors', 'books'],
      info,
      {
        fields: {
          // authors { books { author { * } } } は禁止する
          author: undefined,
          // authors { books { category { * }  } } を許可する
          category: {}
        }
      }
    )
    return authors
  }
}

PreloadResolverを使った再帰的なpreload

PreloadResolverを使って再帰的なpreload処理を行うことも可能です。

例えばBookとAuthorを再帰的にpreloadする処理はこのようになります。

class GraphQLPreloadResolver {
  constructor(
    private readonly bookPreloadResolver: BookPreloadResolver,
    private readonly authorPreloadResolver: AuthorPreloadResolver,
  } {
  }

  async preloadBook(
    books: Book[],
    path: string[],
    info: GraphQLResolveInfo,
    options: { fields: BookPreloadFields }
  ): Promise<void> {
    // booksが空の場合は再帰処理を終了
    if(books.length == 0) { return }

    const { authors } = await this.bookPreloadResolver.preload(books, path, info, options)

    // 再帰的にauthorsのpreload処理を行う
    await this.preloadAuthor(
      authors,
      path.concat(['author']),
      info,
      { fields: options.fields.author ?? {} }
    )
  }

  async preloadAuthor(
    authors: Author[],
    path: string[],
    info: GraphQLResolveInfo,
    options: { fields: BookPreloadFields }
  ): Promise<void> {
    // authorsが空の場合は再帰処理を終了
    if(authors.length == 0) { return }

    const { books } = await this.authorPreloadResolver.preload(authors, path, info, options)

    // 再帰的にbooksのpreload処理を行う
    await this.preloadBook(
      books,
      path.concat(['author']),
      info,
      { fields: options.fields.books ?? {} }
    )
  }
}

このGraphQLPreloadResolverを使って

  • book(id: 1) { author { books { name } } }
  • books { author { name } }
  • authors { books { name } }

というpreloadをしようとする場合は、QueryResolverをこのように書きます。

class QueryResolver {
  constructor(
    private readonly graphQLPreloadResolver: GraphQLPreloadResolver
  ) {}

  async book(args: { id: string }, info: GraphQLResolveInfo): Promise<Book | null> {
    const book = await findBookById(args.id)
    if(!book) { return null }
    await this.graphQLPreloadResolver.preloadBook(
      [book],
      ['book'],
      info,
      // books { author { books { * } } }を許可する
      { fields: { author: { books: {} } } }
    )
    return book
  }

  async books(_args: {}, info: GraphQLResolveInfo): Promise<Book[]> {
    const books = await findAllBooks()
    await this.graphQLPreloadResolver.preloadBook(
      books,
      ['books'],
      info,
      // books { author { * } }を許可する
      { fields: { author: {} } }
    )
    return books
  }

  async authors(_args: {}, info: GraphQLResolveInfo): Promise<Author[]> {
    const authors = await findAllAuthors()
    await this.graphQLPreloadResolver.preloadAuthor(
      authors,
      ['authors'],
      info,
      // authors { books { * } }を許可する
      { fields: { books: {} } }
    )
    return authors
  }
}

再帰処理で実装することで、冗長な処理が少なく記述できているかと思います。

PreloadResolverによるN+1問題対応のメリット&デメリット

個人的にはPreloadResolverには以下のようなメリットがあると思っています。

  • 処理の流れがわかりやすい
  • ネストしたN+1問題に再帰処理で対応できる
  • field毎のpreloadの許可・禁止を設定しやすい

一方で、DataLoaderで困らない用途であればDataLoaderを使った方が記述量は少なく済むかと思うので、その辺りは用途に応じて使い分けるのが良いかと思います。

まとめ

今回はハナストのGraphQL APIでのN+1問題への対応としてPreloadResolverというアプローチをご紹介しました。

ハナストチームではGraphQL技術を活用して介護領域での音声AIサービスの開発を行なっており、一緒に働いていただける方を積極的に募集しています。

GraphQL技術を使った社会課題の解決などに少しでも興味がありましたら、ハナストチームおよびエクサウィザーズに是非ご応募ください。

hrmos.co

日経コンピュータ・日経xTECHで、エクサウィザーズの機械学習エンジニアによる連載を掲載しています

 日経コンピュータの5月26日号(日経BP)から、エクサウィザーズの機械学習エンジニアを中心とした著者による長期連載が始まりました。AI技術の最新動向と応用事例について解説していきます。

●5回目はエクサウィザーズ 機械学習エンジニアのサヒリ・モハメッド、浅谷 学嗣が担当しました。

「AIモデルと処理の軽量化 エッジデバイスで必須に」

IoTでエッジデバイスにおけるAI(人工知能)活用が広がっている。コストと性能を両立させるために必須なのがAIモデルの軽量化だ。ただし手法が数多くあり、選択や活用に注意が必要だ。

▽詳しくは下記をご覧ください(外部リンク、2ページ目以降有料)

xtech.nikkei.com

●4回目はエクサウィザーズ 機械学習エンジニアの石丸 裕吾、西日本事業部/エネルギー環境企画部 事業部長の長谷川 大貴が担当しました。

「数理最適化で意思決定 予測×制約で導き出す」

AIのビジネス活用において数理最適化の重要性が高まっている。現場や経営で必要とされる制約条件を考慮できるからだ。成果を得るための意思決定においてさまざまな分野で活用され始めている。

▽詳しくは下記をご覧ください(外部リンク、2ページ目以降有料)

xtech.nikkei.com

●3回目はエクサウィザーズ 機械学習エンジニアの神戸宏之が担当しました。

「追加学習が不要な「GPT-3」 文章生成などビジネス活用も」

「GPT-3」は自然言語処理分野にパラダイム変化をもたらした。テキストを入力するだけで、それに「答える文章」の予測が可能になったからだ。課題は多いが、マーケティング文章の生成などビジネス活用が始まっている。

▽詳しくは下記をご覧ください(外部リンク、2ページ目以降有料)

xtech.nikkei.com

●2回目はエクサウィザーズ 機械学習エンジニアの小野晃司が担当しました。

プライバシー保護の切り札 「連合学習」が普及期に

「連合学習」はデータそのものを収集せず機械学習モデルを作成できる。携帯の予測変換やクッキー代替などでの活用が始まっている。プライバシー重視のヘルスケアや金融などでの活用が有望視されている。

▽詳しくは下記をご覧ください(外部リンク、2ページ目以降有料)

xtech.nikkei.com

●初回はエクサウィザーズ AI技術統括の遠藤太一郎が担当しました。

「AIの精度を左右する3技術 4ギルドで体制づくり」

AIの精度を左右する最新動向として3つの技術を紹介する。「自己教師あり学習」「マルチモーダル」「MLOps」をうまく取り入れる必要がある。4つのギルドから成る組織体制が、最新動向のキャッチアップに欠かせない。DX(デジタルトランスフォーメーション)の推進が企業の重要課題となっており、その差異化の手段としてのAI(人工知能)に対する期待は高まるばかりだ。AIの主たる要素である機械学習は近年どのように進展し、ビジネスに活用されるようになっているのか。本連載では機械学習のビジネス応用を専門とする筆者が、最新動向と企業事例について解説する。

▽詳しくは下記をご覧ください(外部リンク、2ページ目以降有料) xtech.nikkei.com

エクサウィザーズのTLが実践する、開発が遅くならないテストの書き方

この記事について

この記事ではエクサウィザーズの介護記録AIアプリ「CareWiz ハナスト」(以下ハナスト)の開発スピードを維持するために、どのようにテストを書いているかをご紹介します。

内容としては基本的なことかと思うので、ハナスト開発ではどのような基本に則ってテストしているかという感じで読んでいただければ良いかと思います。

書いているのは誰?

この記事はハナスト開発チームのテックリードをしている原(@haracane)が書いています。

ハナストチームでは主にNode.js&TypeScriptでバックエンドAPIを開発していてテストにはJestを使っています。

ちなみにこれまではKotlin&JUnitやRuby on Rails&Rspecなどで開発&テストをしたりしてました。

ハナストについて

ハナストは簡単に言うと「音声入力で介護の記録をするアプリ」です。

以下の動画を見ていただくと、大体どんなアプリかわかるかと思います。

vimeo.com

ハナストの開発プロセス

やや脱線しますが、ハナストの開発プロセスについてはこちらのnote記事によくまとまっています。

note.exawizards.com

とても良いことが書いてあるので是非読んでいただきたいのですが、この記事に関係するところで言うと、ハナストの開発はあくまで「仮説検証のためにやっている」というところがポイントです。

効果的に仮説検証を進めるには開発スピードを維持することが重要になります。

そのためにどのような工夫をしているか、ということをこの記事にはまとめています。

ハナストの構成

続いてハナストの実装について紹介すると、ハナストは大きく分けて

  • バックエンド API
  • 音声認識 AI
  • iOS アプリ

の3つで構成しています。

今回は主にハナストAPIでどのようにテストしているかをご紹介します。

ハナストAPIのテスト

ハナストAPIはGraphQL APIとして提供していて、Clean Architectureで設計しています。

テストは主にGraphQLのリクエスト&レスポンスのテストを書いています。

例えば記録を作成するGraphQL APIのテストだとこんな感じです。

describe('createCard', () => {
  describe('with food input', () => {
    let response

    beforeEach(async () => {
      response = await graphQLRequest(`
        mutation createCard {
          createCard(type: "food", amount: 10) {
            id
            type
            amount
          }
        }
      `)
    })


    it('creates food record & renders food record', async () => {
      const cardId = response.data.createCard.id
      const card = await new CardRepository().findById(cardId)
      expect(card).toMatchObject({
        type: 'food',
        amount: 10,
      })

      expect(response).toEqual({
        data: {
          createCard: {
            id: cardId,
            type: 'food',
            amount: 10,
          }
        }
      })
    })
  })
})

逆にユニットテストは極力書かないようにしています。

上記のような記録作成の例だと、Layer毎に

  • GraphQL Layer
    • MutationResolver#createCard
  • UseCase Layer
    • CardService#create
  • Database Layer
    • CardRepository#create

のようなクラス&メソッドを実装していますが、各メソッドのユニットテストは書いていません。

その理由についてご説明する前に、ハナスト APIのテストの役割についてご説明します。

ハナスト開発におけるテストの役割

ハナスト開発でのテストの役割は

  • API仕様と異なる実装を検出すること
  • 開発スピードを落とさないこと

としています。

この二つに優先順位はなく、品質担保と同様に開発スピードの維持も重要な役割としています。

これはハナスト開発のフェーズが仮説検証を繰り返している段階で、素早く機能を開発する必要があるからです。

必要のない機能は作らない

テストを書く、書かない以前に開発スピードを落とさないためには必要のない機能を作らないということが大事です。

基本的な方針として、ハナスト開発ではAPIの仕様を決める時に

  • 必要のない入力を受け付けない
  • 必要のない出力をしない

ようにしています。

API仕様に必要のない入出力があると、その仕様のテストも必要になりますし、後々変更をする際にもその仕様を守るために余計な工数がかかってきて開発スピードが落ちてしまいます。

また、必要のないテストをしないことも同様に重要です。

同じようなテストがいくつもあると、そのテスト作成の工数だけでなく、テスト変更時の工数も増えてしまいます。

ハナストAPIのリクエストテストを書く理由と、ユニットテストを書かない理由

現在のハナストのように開発が活発な状態だと外部APIおよび内部APIの仕様変更も頻繁に発生します。

リクエストテストは外部APIの実装が壊れないようにするためのもので、これは必要です。

リクエストテストがないと、外部APIが仕様通りに動いていない場合に気付くことができません。

一方ユニットテストについては内部APIの実装が壊れないようにするためのものになります。

ハナストの開発ではリファクタリングによって内部APIの仕様を変えるというケースは頻繁に発生します。

ユニットテストが書いてある場合は、内部APIの仕様を変えた時には関係するユニットテストも合わせて修正する必要があります。

リクエストテストは問題なく通っていて、API全体の動作としては問題ない場合でも、ユニットテストが壊れている場合は直さなくてはいけません。

これは「開発スピードを落とさない」という観点からは許容できません。

そのため、ハナストでは基本的にリクエストテストを書いて、原則としてユニットテストを書かないようにしています。

ユニットテストを書く場合

「原則」ユニットテストは書いていませんが、場合によっては書く場合があります。

一つはユニットテストレベルでTDD(テスト駆動開発)をした方が実装しやすい場合です。

よくあるのは値を変換する汎用ロジックのテストです。

例えばUTCの時刻からJSTの日付を取得するようなメソッドなどが該当します。

このようなメソッドはユニットテストを書いてから実装コードを書くTDDスタイルの方が実装しやすいので、ユニットテストを用意しています。

もう一つは他のテストでスタブしているメソッドのテストです。

リクエストテストでメソッドをスタブしている場合、スタブしているメソッドのユニットテストがないと、そのメソッドが全くテストされません。

ハナストの開発ではテストされないロジックは許容していないので、そのような場合はユニットテストを用意しています。

スタブを使う場合と使わない場合

スタブについても「開発スピードを落とさない」ことを考えて使うかどうかを決めています。

前述したように、スタブを用意した場合はユニットテストが必要になるのでスタブも極力使わないようにしています。

これはS3やSQSなどについても同様で、minioやelasticmqなどのクローンを使ってなるべくスタブはしていません。

ただし、

  • クローンが用意されていない外部リソースを利用する場合
  • 外部リソースのエラー時の挙動をテストする場合
  • 外部リソースの状態変化をテストする場合

といったケースでは外部リソースのスタブを許容しています。

この場合、スタブ対象が十分にテストされていることが望ましいので、なるべくライブラリのクライアントなどを直接スタブするようにしています。

Factoryを積極的に使う

テスト用のデータ(例えばUserデータなど)についてもやはりモックは使いません。

これは、安易にモックしてしまうと矛盾のあるデータが作られてしまいやすく、テストが不安定になるからです。

その代わりにハナストの開発ではテスト用のデータ作成用のFactoryクラスを用意しています。

例えばUserFactoryなら

new UserFactory().create()

という感じで呼び出すと、デフォルトのパラメタでUserデータを作成します。

テストのためにパラメタの指定が必要な場合は

new UserFactory().create({ familyName: 'ハナスト', givenName: '太郎' })

のようにパラメタを指定します。

Factory側では依存するレコードの作成なども含めて、矛盾のないデータを作成するようにしています。

CIは並列実行する

  • 基本的に(ユニットテストではなく)リクエストテストを書く
  • スタブはしない

という方針でテストを書いていると、テストの実行時間は長くなりがちです。

その結果CIの待ち時間が長くなってしまうと、やはり開発スピードが落ちてしまうので、CIに時間がかかるようになってきたら適宜並列化して、全体で5分程度で終わるようにしています。

並列化の方法はシンプルで、例えばテスト用のディレクトリ構成が

  • spec
    • entity
    • usecase
    • request

のような構成になっていたら、entity・usecase・requestディレクトリのテストをそれぞれ並列に実行します。

ハナストのCIはGithub Actionを使っているので、上記のような例ですと3並列のワークフローを定義してCIを実行すればOKです。

まとめ

今回はハナストの開発スピードを落とさないための基本的なテスト戦略についてご紹介しました。

他にもハナストでは音声認識技術や音声入力のインタフェースなどをチームメンバーが各自の役割を発揮して開発していますが、介護の世界でやるべきことはまだまだたくさんあります。

そのような社会的な課題の解決をより進めるために、ハナストチームおよびエクサウィザーズでは一緒に働く人を募集しています。

介護をより良くするハナストの開発、あるいはAIで社会課題を解決するエクサウィザーズに少しでも興味がありましたら、是非ご応募ください!

hrmos.co

因果推論とグラフ理論

こんにちは。数理最適化ギルドでエンジニアをしている加藤です。

ある自社プロダクトの開発を通じて因果推論について勉強する機会がありました。因果推論は統計の分野ですが、その中で数理最適化の技術が使えることを知り、とても面白かったのでその内容をシェアしようと思います。具体的には組合せ最適化問題のひとつである最小カット問題が、因果推論のタスクの一部である識別可能性に利用できるという話をします。

前半は因果推論についての概説で特に予備知識は仮定していないです。後半は計算時間やネットワークフローなどのアルゴリズムを知っていると読みやすいと思います。

因果推論とは

因果推論の目的

統計的因果推論とは事象の間の因果効果を実験データや観測データから推定することを目的とした統計学の一分野です。単に因果推論といった場合は統計的因果推論を含むより広い概念を指すことがありますが、簡単のため以下では因果推論といえば統計的因果推論のことであるとします。

例えばある販促キャンペーンを実施したときに、本当に効果があったのかどうかなどを合理的に判断することなどに利用できます。この例のように物事の原因となったと想定されるものを処置(トリートメント)変数、その原因に影響を受けたと想定されるものを結果(成果・アウトカム)変数と言います。また処置が結果に与えた影響の大きさを因果効果と言います。

因果推論には主にRubin流の潜在的アウトカムを用いる流儀とPearl流の因果グラフを用いる流儀があります。本稿で説明するのは後者のPearl流因果推論です。

因果関係があるどうかを判断するというのは意外と難しく、その理由の一つは人間は相関関係を因果関係を誤認してしまうことにあります。

因果≠相関

相関・因果の誤謬の例としてしばしば引き合いに出されるのが次のグラフです[1]。

チョコレート消費量とノーベル賞受賞者数の関係

これは一流の医学雑誌に掲載された論文の分析で「一人当たりのチョコレートの年間消費量が多い国ほどノーベル賞受賞者が多い」という相関関係がわかります。論文の著者は「国民が1人あたり年間400g多くチョコレートを摂取すると、その国のノーベル賞受賞者の数が1人増える」と結論づけていますがこれは当然正しい推論ではなく、実際は「一人当たりGDP」という変数が前者と後者の変数どちらには影響を与えることによって生み出されている擬似的な相関です。疑似的な相関(あるいは疑似的な無相関)のことを交絡といい、交絡を作り出すの「一人当たりGDP」のような変数を交絡変数と呼びます。こういった擬似相関を排除したい場合はどうすれば良いのでしょうか?

もう少し単純な例で考えてみます。次の表をご覧ください。

治療結果の表

この表は何かしらの疾病・怪我などに対して治療した場合としなかった場合での生存・死亡数を表しています。治療をした場合もしなかった場合も生存率が50%で変わっていないので一見この治療には効果がないように見えます。しかし次の表を見てください。

男女別治療結果の表

この表をみると男性を治療した場合の生存率は約62%、しなかった場合は約57%なので治療によって生存率が上がっています。一方女性も治療した場合が約44%、しなかった場合が40%なのでこちらも生存率が上がっています。つまり一見効果がないように見えた治療でしたが、男女別にすると治療効果が見えてくるということがあります。

このように層別化などによって正しい因果効果が推定できるような変数の集合のことを調整変数と呼びます。また調整変数が観測できるとき因果効果は識別可能であると言います。つまり上の例では「チョコレート消費量→ノーベル賞受賞者数」の因果効果を正しく推定するための調整変数が「一人当たりGDP」なので、この因果効果は識別可能であるということになります。

調整変数選択の難しさ

調整変数によって層別化すれば正しい因果効果を推定できると述べましたが、肝心の調整変数はどうやって選択すれば良いのでしょうか?少なくともなんでもかんでも調整変数に含めれば良いわけではないということが次の例でわかります。次の表を見てください。

幼児が遊んだトランプの表

この表は幼児が遊んだトランプの分類です。数値的には先ほどの表と全く同じです。汚れありのグループをみると絵札より数札の方が赤札の割合が高いです。一方汚れなしのグループも絵札より数札の方が赤札の割合が高いです。つまり「汚れの有無にかかわらず、絵札より数札の方が赤札の割合が高い」と読めます。しかし絵札も数札も赤と黒は同数あるので「絵札より数札の方が赤札の割合が高い」という主張は当然正しくありません。したがってこの例では「汚れの有無」という属性で分けることで間違った結論を導いてしまいます。

上記の例のように層別化することでバイアスが発生してしまうことを選択バイアスと呼びます。*1

このように何を調整変数にすべきか・そもそも調整変数は存在するのか(識別可能どうか)を判断するのは難しい問題ですが、その問いへの回答として冒頭でも触れたPearlの因果グラフの理論があります。続く記事ではその内容を解説します。

因果グラフ

グラフとは

グラフとは物事のつながりを表現するための構造のひとつです。現実世界のさまざまなものをモデル化する力があり、因果推論以外の文脈でも頻繁に現れる対象です。

数学的にはノードとエッジと呼ばれる対象から構成され、ノードはグラフの頂点でエッジはふたつのノードをつなぐ辺を意味します。フォーマルな定義は以下です。ついでに後々使う概念の定義もしておきます。

定義(グラフ)G = (V, E)をグラフという。Vの要素をノードあるいは頂点という。EVの要素のペアからなる集合で、その要素をエッジあるいは辺という。エッジに向きが指定されている場合そのエッジは向きがついているという。またUからVへの向きがついたエッジはU \to Vと書く。またUをエッジの始点、Vをエッジの終点という。

定義(無向グラフ、有向グラフ、単純グラフ)全てのエッジに向きが付いているグラフを有向グラフという。逆に全てのエッジに向きがついていないグラフを無向グラフという。全てのノードのペアに対して高々一つしか辺がなく、全てのノードが自分自身へのエッジを持たないグラフを単純グラフという。以後グラフは全て単純であるとする。

定義(パス、ループ)ノードの列N_0, N_1, \cdots, N_pであって全てのiに対してN_iN_{i + 1}の間にエッジが存在するものをパスという。またそれらが全てN_iからN_{i + 1}へ向きがついている場合、このパスを有向パスという。またN_0およびN_pをそれぞれパスの始点および終点という。始点と終点が一致しているパスをループと呼び、それが有向パスである場合に有効ループと呼ぶ。

定義(親、祖先、子、子孫)有向グラフのノードNに対してNに入ってくるエッジの始点であるノードをNの親、Nから出ていくノードの終点であるノードを子、Nを終点とする有向パスの始点を祖先、Nを終点とする有向パスの終点を子孫という。

定義(DAG)単純な有向グラフであって有向ループを持たないものをDAG(有効非巡回グラフ)という。

水色の丸がノード 青い矢印がエッジこの例は特にDAGになっている

因果推論においてはノードに事象あるいは変数を対応させ、エッジは因果関係を表します。このように因果関係をグラフで表現したものを因果グラフと呼びます。また因果グラフ上でノードUからノードVにエッジが存在するときUVの原因であるということがあります。

交絡の例

因果グラフを使うと交絡変数は次のように表せます。ただしT, Y, Uがそれぞれ処置変数、結果変数、交絡変数を表します。以下で矢印の実線は因果関係を、波線は交絡を表していると考えてください。

交絡変数によって処置変数と結果変数が関係し合っている

交絡変数Uで調整することにより交絡を取り除く様子を次のように表現します。

交絡変数Uを調整することにより交絡を取り除く

逆に調整変数にすべきでない変数による交絡が生じる場合もあります。以下のように調整変数も結果変数も原因であるような変数で調整すると選択バイアスによる交絡が生まれます。

Sを調整変数に選択することによって交絡が生まれる

このように因果関係をグラフによって表現し、交絡変数や選択バイアスを視覚的に表現できます。もう少し複雑な例を考えてみましょう。次のグラフをご覧ください。

複雑な因果グラフ

上記の例ではまずVが交絡変数になっています。またUTVの交絡因子になっており、VYの原因なのでT \gets U \to V \to Yというパスを介してTYは関係しあっています。同様にT \gets V \gets W \to YというパスもありこれもTYの交絡を作り出しています。このような処置変数Tに入ってくるエッジからスタートする結果変数Yへのパスをバックドアパスと呼びます。

3つの交絡

これらの交絡を取り除くためにはどのような調整変数を選択すれば良いでしょうか。実はVを調整変数にすると上記3つの交絡は全て取り除けます。なぜなら上記3つの交絡をもたらすバックドアパスの経路上にVが存在して交絡をブロックしているためです。バックドアパスとパスのブロックのフォーマルな定義は以下になります。

定義(バックドアパス)有向グラフG = (V, E)の頂点T, Y \in Vに対してTからYへの有向とは限らないパスであってTへ入ってくるエッジからスタートするものをバックドアパスという。

定義(パスのブロック)有向グラフG = (V, E)とその頂点の部分集合S \subset VおよびG上の有向とは限らないパス\gamma が次のいずれかの条件を満たすとき S\gamma をブロックするという。

1. \gamma上の合流点であって、自分を含む全ての子孫がSに含まれないものが存在する。 2. \gamma上の非合流点であってSに含まれるものが存在する。

ただしパス上の端点以外のノードであってパス上で隣り合う二点からエッジが入ってくるようなノードを合流点と呼んでいる。

しかし一方で別の問題が発生します。それはVを調整変数にすることでU \to V \gets Wに選択バイアスが生まれT \gets U \to V \gets W \to Yという新たな交絡が生じてしまうのです。この新たな交絡を取り除くためには{U, V}などを調整変数として選ぶ必要があります。

Vで調整すると新たな交絡が生まれるのでUも選ぶ

上のような特殊な例に限らず交絡を取り除くためにはどんな調整変数を選べば良いのでしょうか。その答えが次のd分離性です。

定義(d分離)有向グラフG = (V, E)の頂点T, Y \in Vと頂点の部分集合S \subset Vが次の性質を満たすときS\{T, Y\}をd分離(有効分離)するという。TからYへの全ての有向とは限らないパス\gammaに対してS\gammaをブロックする。

少し面倒ですが上述の特殊な例において\{U, V\}\{T, Y\}をd分離していることが確認できます。

有効分離があるのに無向分離はないのかと思われるかも知れませんが一応こちらもあります。後で使うのでここで定義しておきます。

定義(u分離)有向グラフG = (V, E)の頂点T, Y \in Vと頂点の部分集合S \subset Vが次の性質を満たすときS\{T, Y\}をu分離(無向分離)するという。TからYへの全ての(無向)パス\gammaに対して\gamma上にSの要素が少なくとも1つ存在する。

バックドア基準

さらに一歩進んでバックドア基準についても述べておきます。交絡を取り除いて因果効果のみを知りたいときに有用な概念です。

定義(バックドア基準)因果グラフG = (V, E)とそのふたつの頂点T, Y \in Vに対して次の条件を満たすノードの集合S \subset Vをバックドア基準を満たすという。

  1. すべてのノードZ \in Sに対してTからZへの有向パスが存在しない。
  2. \{T, Y\} \cup Sの祖先グラフ上でSTYをd分離する。

ただし有向グラフG = (V, E)のノードの集合V_0 \subset Vに対してその祖先グラフとは、ノードV^{an}V_0のすべての要素およびその祖先とし、エッジE^{an}Gのエッジであってその始点・終点がV^{an}の要素であるものとするグラフのことである。

やや定義が複雑ですがバックドア基準を満たす変数集合を調整変数として選べば介入効果などが表現できることが知られています[2]。

ところで処置変数の親を全て調整変数にすれば必ずバックドア基準を満たすことも知られています。しかしこれは必ずしも良い選択ではありません。調整変数は層別化や回帰といった用途に用いられる変数なので、できるだけ少ない数の調整変数が知りたいという欲求があります。

つまり我々の目標は次のようになります

バックドア基準を満たす最小の変数集合を見つけたい!

定義よりバックドア基準はd分離性と密接に関わっていることがわかると思いますので次のような目標設定でも良いです。

d分離性を満たす最小の変数集合を見つけたい!

次の章では上の目標を実現するアルゴリズムについて解説します。

調整変数選択のアルゴリズム

d分離性の言い換え

ここではd分離性を満たす最小の変数の集合をどうやって見つけたらよいかを、計算量にも触れつつお話しようと思います。

ひとつの方法は変数の部分集合を全て列挙してその集合がd分離性を満たすかどうかを判定することです。この方法はすぐに思いつきますし、実際にこのように実装されているライブラリもあるようです。実用上はあまり問題がないかも知れませんが、変数の集合が多くなるとこの方法は非常に遅くなります。実際変数がN個だとすると計算時間がO(2 ^ N)もかかるので規模が大きくなると途端に遅くなるでしょう。

アルゴリズムに詳しい人ならd分離性の定義を見たときに「ソース(処置変数)からシンク(結果変数)への全てのパスが何かしらの意味でブロックされているという定義だから、最小カットのようなアルゴリズムが使えそうだ」と思いつくかも知れません。これから述べるようにその直感は正しいのですが、最小カットを応用するためにはd分離性の定義を言い換える必要があります。それが次のモラルグラフを使った言い換えです。

定義(モラルグラフ)G = (V, E)を有向非巡回グラフとする。GのモラルグラフG^mとはVをノードとし、次で定義するエッジの集合をもつ無向グラフことである。Vの2頂点v_1およびv_2に対して次のいずれかが成り立つとき\{v_1, v_2\}G^mのエッジであるとする。

  1. v_1 \to v_2あるいはv_2 \to v_1Gのエッジである。
  2. G内で、v_1およびv_2を共通の親としてもつノードが存在する。

モラルグラフの作り方

ここでモラルグラフの構築にかかる計算量を調べてみると、全てのノードのペアに対して上記2つの条件のうちいずれかを満たすかどうかをチェックするだけなのでO(N^2)程度です。モラルグラフに対して次の定理が成り立つため非常に有用な概念になっています[3]。

定理([3], Theorem 3)G = (V, E)を有向非巡回グラフとする。頂点\{T, Y\}に対してS\{T, Y\}の祖先の集合の部分集合であるとする。このときS\{T, Y\}をd分離することは\{T, Y\}の祖先グラフのモラルグラフH = G_{An(T, Y)}^m内でS\{T, Y\}をu分離することと同値である。

この言い換えによりd分離集合を求めるためには

  1. 処置変数と結果変数の祖先グラフのモラルグラフを構築する。
  2. モラルグラフ内でu分離する頂点集合を求める

というステップを踏めば良いことがわかりました。1.に関しては解説しましたので次は2.について解説していきます。

最小カット問題

モラルグラフ内でu分離する頂点集合を求めるというタスクは最小カット問題で解くことができます。最小カット問題とは次のような問題です。ここだけの用語ですが有向グラフGの2頂点u, vとエッジの集合Fに対してGからFを取り除いたグラフにおいてuからvへの有向パスが存在しないとき、Fu, vをカットするということにします。

有向グラフG = (V, E)がある。また全てのエッジにはコストという値が定められているとする。始点Sと終点Gが定められている。始点と終点をカットする「エッジのコストの合計」の最小数はいくらか?

カットの例

モラルグラフのu分離性に使えそうな設定の問題ですね!でも後一歩必要です。最小カット問題はエッジを除去することでパスをブロックする問題ですが、u分離性は頂点によってパスをブロックするからです。このギャップを埋めるためには次のようなグラフを考えれば良いです。

モラルグラフHのノードuに対してu_+およびu_-というノードとu_+ \to u_-というエッジを作りコスト1を与えます。またuvをつなぐエッジがHに存在する場合、u_- \to v_+およびv_- \to u_+というエッジを作りコスト\inftyを与えます。以上で作ったノードとエッジからなるグラフを\tilde{H}と書きます。

最小カットに帰着する

H\tilde{H}に対して次のことが言えます。

H内でTYをu分離するノードの集合の最小数と、\tilde{H}内でT_-Y_+の間の最小カット数は同じである。また次のように一方の問題の解から他方の問題の解への変換ができる。

  1. HにおいてST, Yをu分離するノードの集合だとすると、\tilde{H}においてF = \{u_+ \to u_-\ | u \in S\}T_-, Y_+をカットする。
  2. \tilde{H}においてFT_-, Y_+をカットするエッジの集合だとすると、HにおいてS = \{u | u_+ \to u_- \in F\}T, Yをu分離する。

これでモラルグラフのu分離集合の求め方がわかりました!これとd分離性のモラルグラフでの言い換えと合わせると結局d分離性を満たす最小の変数集合を高速に計算できることになります!

また使ったアルゴリズムも最小カットなので適当なソルバーを使えば高速に計算できます。\tilde{H}のエッジおよびノードはそれぞれ高々O(N^2)およびO(N)のオーダーなのでDinicのアルゴリズムなどで最悪O(N^4)程度の計算時間になります。愚直な方法が指数時間であることを考えると劇的な改善です!

終わりに

今回は因果推論におけるアルゴリズムの活躍についてお話ししました。筆者の乱文をここまで読んで下さった方には大変感謝いたします。以下のリストは調整変数の選択とアルゴリズムに関連する内容として興味深いですが筆者の勉強と気力が不足していて書けませんでした。またTechBlogを書く機会があってやる気に満ち溢れていたら勉強して書こうと思います。

  • フロントドア基準・操作変数法
  • optimalな調整変数の選択
  • mixed graphでの識別可能性
  • IDアルゴリズム

エクサウィザーズでは一緒に働く人を募集しています。中途、新卒両方採用していますので、興味のある方は是非ご応募ください!

参考文献

  1. Messerli, F. H. (2012). Chocolate Consumption, Cognitive Function, and Nobel Laureates. The New England Journal of Medicine, 367, 1562-1564.
  2. 宮川雅巳. (2004). 統計的因果推論ー回帰分析の新しい枠組みー. 朝倉書店.
  3. Tian, J., & Paz, A., & Pearl, J. (1998). Finding Minimal D-separators. Tech. Rep. R-254, Univ. of California, Los Angeles.
  4. Acid, S., & De Campos, L. M. (2013). An algorithm for finding minimum d-separating sets in belief networks. arXiv preprint arXiv:1302.3549.

*1:選択バイアスはより広い意味で用いられることがありますがこの記事では上記の意味でのみ用います。

Zero-shot Learning入門

こんにちは。エクサウィザーズで画像ギルドに所属し、機械学習エンジニアをしている小島です。今年の3月からこちらにジョインいたしました。

この記事では、弊チームで取り組んいるテーマ「Zero-shot Learning」について、歴史的な背景を振り返りつつ、簡単な実装を紹介します。今研究でホットな研究テーマの一つである「クロスモーダルモデル」を身近に感じていただければ幸いです。

Zero-shot Learningとは

「Zero-shot Learningとは何か」というのは、実は曖昧なテーマです。「これがZero-shotだ」という定義が論文によって異なるためです。わかりやすい理解の仕方としては、Many-Shot Learning、One/Few-shot Learningから天下り的に考えていくことでしょう。

画像系の機械学習の問題は、大きく分けて、タスクの軸データ数の軸の2軸で考えられます。

タスクの軸については、分類問題(Classification)、物体検出(Object Detection)、セマンティック/インスタンスセグメンテーション(Instance/Semantic Segmentation)を最も基本的な3つのタスクとして挙げています。タスクの軸はこれ以外にもいろいろあるので、この3つが正解というわけではありません。

本題はデータ数の軸で、ここでは「タスク固有の訓練データ」を意味します。例えば、「犬か猫か」の画像分類モデルを訓練したい場合、犬の画像を数百枚、猫の画像を数百枚持ってきて、ImageNetで訓練されたモデルをfine-tuningするのが一般的なやり方でしょう。この場合、タスク固有の訓練データは「数百枚×クラス数(犬or猫=2)」必要となり、データ数の軸では「Many-shot」となります。

では、Few/One-shotとは何でしょうか。One-shotとは文字通り、タスク固有の訓練データが1つ、Fewなら数枚か~少し多い程度でしょうか。Few/One-Shotの典型例は顔認証です。例えば、1人あたり数百枚の顔写真を入れ、顔認証をMany-shotの問題として訓練・運用するのは現実的でありません。データが取れないという問題もありますし、認証対象に1人追加されるとクラス数が変わるため、モデル全体を訓練し直さないといけないからです。この図は[1]の論文からのものです。

One-shot Learningでは、「クラスが同一かどうかを学習している」点に注意してください。Many-shotでは、犬や猫といった特定のクラスに属するかを学習していました。ここで大事なのが、学習やタスクの設定を工夫すれば、タスク固有の訓練データは減らせるという点です。ここの改良はZero-shotでも続けられています。

Zero-shot黎明期

ディープラーニングのZero-shotの論文に行く前に、Zero-shotの発端となりうる論文を2つ紹介します。

Zero-shot Learning

1つ目は、Palatucci et al.(2009)によって書かれた論文です[2]。AlexNet[3]が2012年に発表されたので、これはディープラーニングのブームが来る前の論文です。論文のタイトルにも「Zero-shot Learning」とついていますが、本文中にZero-shot Learningについて問題提起しています(※翻訳はDeepLで作成しています)。

Given a semantic encoding of a large set of concept classes, can we build a classifier to recognize classes that were omitted from the training set?
大規模な概念クラス集合の意味論的符号化が与えられたとき、学習集合から漏れてしまったクラスを認識する分類器を構築できるか?

また、アブストラクトでは、Zero-shot Learningを以下のように定義しています。

We consider the problem of zero-shot learning, where the goal is to learn a classifier f : X \to Y that must predict novel values of Y that were omitted from the training set.
学習セットから漏れてしまった新しいYの値を予測しなければならない分類器f : X \to Yを学習する、ゼロショット学習の問題を考える。

訓練データにはないクラスを認識することがポイントで、この観点は現在のZero-Shot Learningでも根底にある思想の1つです。

なお、この論文の著者にはジェフリー・ヒントンとトム・ミッチェルがいます。ジェフリー・ヒントンはディープラーニング界のゴッドファーザーの1人として非常に有名な研究者です。トム・ミッチェルの「機械学習の定義」は非常によく引用されるので、どこかで見た方もいらっしゃるのではないでしょうか。彼の著書[4]からです

A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P, if its performance at tasks in T, as measured by P, improves with experience E.
コンピュータ・プログラムが、あるタスクのクラスTと性能指標Pに関して、経験Eから学習すると言われるのは、Tのタスクにおける性能が、Pによって測定されるように、経験Eとともに向上する場合である。

Zero-data Learning

2つ目はLarochelle et al.(2008)らによって書かれた、「Zero-data Learning of New Tasks」という論文[5]です。「Zero-data Learning」と「Zero-shot Learning」紛らわしいですね。図はZero-data Learningの論文からです。

「Zero-data Learning」は以下のように定義されています。

We introduce the problem of zero-data learning, where a model must generalize to classes or tasks for which no training data are available and only a description of the classes or tasks are provided.
ゼロデータ学習とは、学習データがなく、クラスやタスクの説明のみが提供されているクラスやタスクに対してモデルを汎化しなければならない問題である。

Zero-data Learningの特徴として「学習データがない」という点が明確に記載されています。また、クラスやタスクの説明のみが提供されているということも興味深い。実はこのZero-data Learningの説明こそ、後で紹介するCLIPのやっていることと一致するという点も、頭の片隅に入れておきたい点です。

論文の図を見ると、左側は手書き数字の「1,2,3」のみ訓練データとして与え、「A, B」を推論するケースで、これは完全に学習データがないケースです。真ん中はマルチタスク学習で、欠損値補完のイメージでしょうか。いずれにしても「学習データがないクラスや組み合わせに対して推論する」という点がポイントです。

冒頭で、「現在のZero-shot LearningはMany-shotやFew/One-shot Learningから天下り的に考えると、タスク固有の訓練データがない学習方法だ」ということを述べました。まさにこれが、Zero-data Learningの考え方と重なるわけです。

すなわち、現在のディープラーニングでZero-shot Learningと呼ばれているものは、厳密にいえば、Palatucci el al. (2009)のZero-shot Learningと、Larochelle et al. (2008)のZero-data Learningの少なくとも2つの側面があると言えます。現状では、各論文の著者が「これがZero-shotだ」と言っているので、源流の論文を明確に意識する機会はほとんどありません。しかし、10年以上前の論文の思想が、現在も普遍的に生きていることには驚かされます。

Larochelle et al. (2008)の著者には、同じくディープラーニングのゴッドファーザーと呼ばれるヨシュア・ベンジオが含まれています。AI界の伝説の存在が10年以上前に揃ってこのトピックに取り組んだことから考えると、Zero-shotはそれだけに魅力的な問題なのでしょう。

時代はテキストと画像のクロスモーダルへ

クロスモーダルのアプローチ

Zero-shotを「テキスト」と「画像」のクロスモーダルの問題として捉えたのが、Socher et al.(2013)[6]です。これまでのZero-shot LearningもZero-data Learningも、教師データがあるのを前提にしていました。本論文の大きな違いは、あらかじめ用意された大規模なコーパスから、ラベルのテキストの埋め込み量と外れ値検出を組み合わせて、未知の画像のクラスを推定している点です。

イメージとしては、「truck」と「cat」が未知のクラスだとします。猫の画像が与えられた場合、まず既知のクラスに対して「未知である」という外れ値検出します。そして、テキストの埋め込み量を使い、近傍の「猫」である推定します。ポイントは、大規模なコーパスを使っているため、クラスのラベルとしては、トラックも猫も既知であることです。

なお、この論文の著者には、同じく機械学習界の重鎮であるアンドリュー・ンがいます。私も含め、彼のCourseraの講義にはお世話になった方が多いのではないでしょうか。

DeViSE

Socher et al.(2013)[6]をもう少しこなれた形にしたのが、DeViSE[7]です。DeViSEでは、外れ値検出を使わず、画像と単語の類似度を直接学習していきます。後で紹介するCLIPにかなり近いモデルです。

クロスモーダルのアプローチを使ったのはこの論文が初ではありませんが、要旨には「ラベル付き画像だけでは学習データの取得に限界があるから、テキストデータを活用して精度を高めよう」という意図が書かれています。私がZero-shotを調べたときに「なぜ学習データがないタスクに対して、text-supervisedのアプローチがデファクトスタンダードとして扱われているのか」という点が理解できませんでした。おそらくDeViSEのあたりから、このような方向性が確立されてきたのかな考えられます。

DeViSEの論文は、Googleのチームによって書かれたものですが、著者にはジェフ・ディーンがいます。Zero-shotはどれだけ界隈のレジェンドを惹き付けるのでしょうか。

Visual N-Grams

Visual N-Grams[8]は、現在のCLIPのように、インターネットからダウンロードした膨大な画像-テキストの組み合わせに注目した論文です。

自然言語処理では古くからN-Gramモデルが使われていますが、この考え方を画像に転用したものです。例えば、港にあるクレーンを1つとっても、人間は「navy yard」や「サンディエゴの港」のような多面的な認識をします。Visual N-GramsでもZero-shotへの言及はありますが、Visual N-Gramsの論点が現在のZero-shotを俯瞰するには有用だと思います。要旨からです。

Real-world image recognition systems need to recognize tens of thousands of classes that constitute a plethora of visual concepts. The traditional approach of annotating thousands of images per class for training is infeasible in such a scenario, prompting the use of webly supervised data. This paper explores the training of image-recognition systems on large numbers of images and associated user comments, without using manually labeled images.
実世界の画像認識システムでは、膨大な数の視覚的概念を構成する何万ものクラスを認識する必要があります。このようなシナリオでは、クラスごとに数千枚の画像にアノテーションを付けて学習する従来のアプローチは現実的ではなく、Webから学習したデータの利用が必要である。本論文では、人手でラベル付けした画像を用いずに、大量の画像と関連するユーザコメントから画像認識システムを学習する方法を検討する。

ポイントは2点あります。

  • 実世界の画像認識を考えると、(人間のような)何万ものクラスを多面的に認識したい。従来のMany-shotのアプローチはこれに向かない
  • アノテーションは、人がつけるのではなく、Webに付随しているデータ(コメント)から直接利用するべき

つまり、従来のMany-shotの問題設定としての限界、アノテーションを人間がしたくないことの2点が、この背景となる思想です。これを知ることで、現在のZero-shotの流れが理解しやすくなります。

Contrastive Learning

最初は次元削減から

今まで、Zero-shot Learning、クロスモーダルに関する論文をいくつか紹介してきましたが、CLIPを理解するためには「Contrastive Learning」について知っておく必要があります。そもそもCLとはどのようにして生まれたのでしょうか?

Contrastive Learningの原案は、これもディープラーニングがブームになる前から生まれていました。Hadsell et al.(2006)[9]は、「spring system」と呼ばれる次元削減のための学習システムを構築しました。これはもともと、高次元のデータを低次元の多様体にマッピングする次元削減の問題として提案されました。図はこちらの論文からです。

黒い丸は類似サンプル、白は似ていないサンプルです。このように似ているサンプル同士を近づけて、似ていないサンプル同士を離すという、バネのようなシステムであったから「spring system」と呼ばれました。spring systemという呼び方は現在ほとんどされませんが、考え方自体は現在のContrastive Learningと同じです。

なお、この著者にはヤン・ルカンがいます。Zero-shot Learningを追っているだけで、「ディープラーニングのゴッドファーザー」と呼ばれているチューリング賞受賞の3人:ヤン・ルカン、ヨシュア・ベンジオ、ジェフリー・ヒントンの論文をコンプリートできます。

SimCLR

ディープラーニングにおけるContrastive Learningの大きなブレイクスルーになったのがSimCLR[10]です。

2020年の論文で、spring boxからいきなり時代を飛び越えましたが、「サンプル間の類似度を学習し、類似したものを近づけようとする」という大枠は変わっていません。\mathcal{T}はData Augmentationで、同一のデータ\mathbb{x}をData Augmentationして、異なるサンプル\mathbb{\tilde{x_i}, \tilde{x_j}}を得ます。これを同じニューラルネットワークに通して、類似度を学習します。SimCLRはクロスモーダルではなく、画像で完結するためこのようなAugmentationを入れています。

SimCLRは、ディープラーニングの問題としてシンプルな枠組みで実現し、教師なし学習でも教師あり学習に迫る精度を達成したことから、大きな注目を集めました。SimCLRのような枠組みは自己教師あり学習(Self-supervised learning)とも呼ばれます。教師ありとはいうものの、自分自身を教師として使うため、従来の教師あり学習のように人間が作ったラベルは必要ありません。これは、先程紹介したクロスモーダルなモデルと発想が似ています。

SimCLRの著者には、またヨシュア・ベンジオがいます。先程紹介したZero-data Learningから14年越しです。これだけ長い期間、インパクトのある研究を出し続けられるのは本当に驚かされます。

なお、この記事を読んでいる方には「Contrastive Learningと、One-shot Learningの文脈で語られることの多いMetric Learningは何が違うのか?」という疑問も抱いた方もいるかもしれません。私もこの疑問を持っていたのですが、最近ではContrastive LearningにMetric Learningの要素を加えた研究[11]もあり、両者の境界が曖昧になっています。厳密には違いがあるのかもしれませんが、方言の違いぐらいの認識で良いと思います。

CLIP誕生

クロスモーダルなZero-shot Learning

これまで長い期間をかけて、

  • Zero-shot Learning
  • 画像と文章のクロスモーダル
  • Contrastive Learning

の3つを紹介しました。これらはすべて「CLIP[12]」を説明するためのパーツです。いよいよCLIPを見ていきましょう。

CLIPはOpenAIによって2021年に発表された激強論文です。Zero-shotかつクロスモーダルで、従来のImageNetで訓練済みのMany-shotのモデルよりも頑強性が高いことが大きく注目されています。以下の図はCLIPの論文からですが、目にした方も多いのではないでしょうか。

CLIPもContrastive Learningをしています。SimCLRでは、1つの画像にData Augmentationを適用して2種類の画像を作り、その類似度を最大化していました。CLIPでは、これが画像とテキストのクロスモーダルなので、1枚の画像とそれに対応するキャプションの類似度を最大化するようなContrastive Learningになります。画像のキャプションとはどういうものかというと、次のようなものです。これはOpen Images Dataset V6[13]のサンプルで、元画像の作者は[14](CC BY 2.0)です。

「Fractal Cauliflower」というのがキャプションです。インターネット上の画像(例:Flicker、Instagram)には、キャプションが付属していることが多いので、それを活用します。この例では、キャプションを通じて「幾何学的」と「カリフラワー」の2つの概念を学習できます。

ただし、テキストと画像の特徴量は同じネットワークでは計算できないため、それぞれ別のモデル(例:Transformer、Vision Transformer)を使います。例えば、バッチサイズをN、特徴量の次元をDとすると、画像とテキストネットワークの出力特徴量は、それぞれ(N, D)次元で表されます。これらの行列積を取り出し、(N, N)というコサイン類似度の行列を作り、対角成分を正例、それ以外を負例としてContrastive LearningするのがCLIPです。

この図も論文からの引用です。ポイントは、コサイン類似度の実装をL2-Normalizeで行われていることで、ベクトルa, bに対するコサイン(類似度)というのは、

\cos(a, b)=\frac{a\cdot b}{|a| |b|} = \frac{a}{|a|}\cdot\frac{b}{|b|}

で表されます。これは高校数学の教科書にも載っている公式ですが、右辺はL2-Normalizeそのものです。行列に拡張したのが上図のコードです。高校数学でおなじみの公式が、最先端の研究のキー技術になっているのは、学校で教えて欲しいぐらいです(機械学習で高校数学がいるよ、と言われるのはこういう理由です)。

ロジットの計算の部分で、さらっと温度付きの計算して、さらに学習パラメーターとするという面白いことをやっているのですが、それはおいておきましょう。

プロンプトエンジニアリング

CLIPでは、訓練時に画像の持っているキャプションを活用しましたが、推論時はどう考えているのでしょうか。推論時の画像はキャプションを持っていない場合も多いです(例:カメラから撮った画像)。CLIPではプロンプト(prompt)という、自然言語処理由来の独特な概念が出てきます。この図はOpenAIのブログ記事からです[15]

飛行機を推定するには、「a photo of a airplane / bird / bear / ...」といろいろなテキストがありますよね。このようにターゲットとなるテキストを複数提示して、最も類似度の高いテキストを選択するのが、CLIPの推論方法です。ここで「a photo of {label}」というテンプレートが、プロンプトです。プロンプトという単語は聞き慣れませんが、Windowsに出てくる「コマンドプロンプト」を連想すると馴染み深いのではないでしょうか。プロンプトを調整して精度を上げていくのがプロンプトエンジニアリングです。機械学習でおなじみの「特徴量エンジニアリング」の仲間として捉えると理解しやすいかもしれません。

CLIPでのプロンプトエンジニアリングは、この他に「A photo of a big {label}」「A photo of a small {label}」というように、複数のプロンプトを組み合わせてアンサンブルします。これが精度の改善に大きく寄与します。CLIPの論文からです。

プロンプトエンジニアリングとアンサンブルにより、ImageNetのZero-shot精度を5%、モデル計算量で4倍改善できたとのことです。ここはこの記事本来の目的ではないので、細かくは書きませんが、CLIPには「プロンプト」というクロスモーダル特有の概念が出てくるよ、ということを覚えておいてください。やっていることは単なるテンプレート構文です。

CLIPとZero-data Learnig、Visual N-Grams

今回は「CLIP」の紹介が主な目的ですが、CLIPより前の古い論文もいくつか掲載しました。その理由は、10年以上も前の古い論文で培われた思想が、CLIPでも生きていることを味わってほしかったのと、CLIPがどういった理念から作られたのかを知ってほしかったからです。私個人が純粋に興味あったからというのもあります。

さて、最初にZero-data Learningの論文を紹介しました。CLIPはZero-shot Learningでありながら、実はZero-data Learningの要素もかなり含んでいます。Zero-data Learningを再度引用してみましょう。

We introduce the problem of zero-data learning, where a model must generalize to classes or tasks for which no training data are available and only a description of the classes or tasks are provided.
ゼロデータ学習とは、学習データがなく、クラスやタスクの説明のみが提供されているクラスやタスクに対してモデルを汎化しなければならない問題である。

単に学習データがないというと誤解を招きますが、CLIPにはMany-shotのようなタスク固有の学習データがないのは事実です。ユーザーがFine-tuningしなくても、学習済みモデルに、飛行機の画像とプロンプトを与えれば、飛行機だと推論できます。「クラスとタスクの説明のみ提供されている」という点はまさにプロンプトです。

つまり、CLIPは「訓練データにはないクラスを認識する」というZero-shot Learningの文脈にありながら、Zero-data Learningをまるで伏線回収のように取り込んでいるのです。13年越しにこのような回収して、大きなブレイクスルーを導き出したのはかなり痺れるものがあります。

Visual N-Gramsは、CLIPの論文内でも比較対象として言及されるほど、強く意識していた論文です。Visual N-Gramsのポイントを再掲します。

  • 実世界の画像認識を考えると、(人間のような)何万ものクラスを多面的に認識したい。従来のMany-shotのアプローチはこれに向かない
  • アノテーションは、人がつけるのではなく、Webに付随しているデータ(コメント)から直接利用するべき

CLIPは、アノテーションの事情もふまえつつ、人間のような多面的かつ何万クラスの認識をしたいという考え方を受け継いでいるわけです。多面的な認識はプロンプトのラベルの単語を変えることで可能にしていますし、アノテーションはクロスモーダルなContrastive Learningによって解決しています。Visual N-Gramsの要旨で書かれたような、実世界の画像認識がCLIPの登場でより現実的になったということが言えるでしょう。

CLIPの欠点

万能のように書いたCLIPですが、実は欠点もあります。それはモデルの訓練に莫大な量のGPUとデータが必要なことです。CLIPの論文からです。

The largest ResNet model, RN50x64, took 18 days to train on 592 V100 GPUs while the largest Vision Transformer took 12 days on 256 V100 GPUs.
最大のResNetモデルであるRN50x64の学習には592台のV100 GPUで18日、最大のVision Transformerは256台のV100 GPUで12日かかりました。

ResNet50×64でV100が592×18=29.2年分も必要です。ResNet50×4に減らしても1回訓練するだけで1.8年分のGPUが必要です。CLIPのような強いモデルをスクラッチから訓練できるのが、ビックテックのような限られた存在になってしまい、技術を寡占されてしまわないかという懸念があります(訓練に億円単位がかかるモデルを自由に利用できるように公開する、仏のような心の持ち主が多数いるのを期待したいです)。

また、必要なデータ数も莫大です。後で説明する後続研究のSLIP[16]によれば、CLIPは4億枚の画像-テキストのデータを使用しており、これを1500万枚に減らしたところViT-BでImageNetのZero-shot精度が37.6%になったと報告されています[17]。300万枚では、同じモデルで17.1%まで落ちてしまいました。Zero-shotの精度を教師ありImageNet並に上げたければ、億単位の膨大なデータを使った訓練が必要という身も蓋もない話になります。

SLIPの論文によれば、ViT-B/16でCLIPを再実装するために、1500万枚に限定しても64台のV100で22.3時間(59.4日分)だったことが記されています。V100を2ヶ月使ってImageNetが40%行けばいい、というのはなかなか考えさせられるものがあります。

CLIPの先にあるもの

激強の論文として紹介したCLIPですが、訓練済みCLIPを使った研究が最近(2022年)とても活発です。CLIPのその後の論文をいくつか紹介していきましょう。

CLIPの改良

まずはダイレクトにCLIPの学習をより効率的にする方法です。1つ目はUC BerkeleyとFacebook AIのチームが2021年に発表したSLIP[16]で、図は論文からです。これはCLIPの画像側に自己教師あり学習を組み合わせ、先程紹介したSimCLRを画像側に追加したものです。CLIPよりも精度が高くなります。

2つ目はICLR 2022で発表されたDeCLIP[18]です。(1)モダリティ内、(2)複数のビューを通じたモダリティ間、(3)類似ペアから最近傍、それぞれの自己教師あり学習します。これにより、CLIPよりも7.1倍少ないデータ数で同じ精度を達成できています。

生成モデルへの応用

訓練済みCLIPと生成モデルは相性が良く、AnyFace[19]ではテキストを使って、ソースの顔写真から自在に合成できる手法です。

DALL・E 2[20][21]は、「馬に乗る宇宙飛行士」のようにテキストを入力すると画像を生成するモデルで、非常に精緻な画像が生成されることで話題になりました。この実装には訓練済みCLIPと拡散モデルが使われています。画像はOpenAIのサイトからです。

生成モデル以外にもCLIPの応用は多数あるのですが、これを紹介するときりがなくなるので割愛します。CLIPの活用がまさに研究の最先端で行われています。

SLIP(CLIPの後続研究)を動かしてみる

記事の締めくくりとして、Zero-shotのモデルを実際に動かしてみましょう。CLIPの推論は有用な記事が多数あるので、ここでは後続研究のSLIP[16]を動かします。CLIPのように使い勝手の良いAPIは整備されていませんが、公式のコード[17]をいじることで比較的簡単に実装できます。

ダウンロード

まずはSLIPの公式リポジトリからソースコードをダウンロードしてきます。PyTorchとtimmが別途必要です。

同じリポジトリから、訓練済みモデル(チェックポイント)をダウンロードします。いくつかモデルがありますが、ここでは「ViT-Base / SLIP / 100Epochs」(0-shot 45.0%)をダウンロードします。「Weights」のurlからダウンロードできます。2GB近いファイルなので、容量に気をつけてください。

先程クローンしたリポジトリの直下に「weights」フォルダを作り、そこに保存します。

チェックポイントの軽量化

ダウンロードしたチェックポイントは、推論に不要なパラメーターも含んでいるので、軽量化します。これにより、1.9GB→651MBに軽量化されます。元のチェックポイントは削除して構いません。

import torch
from collections import OrderedDict
from tokenizer import SimpleTokenizer
import models
from torchvision import transforms
from PIL import Image
import numpy as np

def strip_checkpoint(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    old_args = ckpt['args']
    print("=> creating model: {}".format(old_args.model))
    model = getattr(models, old_args.model)(rand_embed=False,
        ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim)
    model.cpu()
    model.load_state_dict(state_dict, strict=True)
    print("=> loaded resume checkpoint (epoch {})".format(ckpt['epoch']))

    torch.save(model, "weights/"+old_args.model)

strip_checkpoint("weights/slip_base_100ep.pt")

推論関数の作成

SLIPにはCLIPのように簡単に利用できる推論APIがないので、それっぽいものを作ります。公式コードをアレンジしたものです。

def load(model_name, device):
    model = torch.load("weights/"+model_name).to(device)

    image_preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        lambda x: x.convert('RGB'),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    ])
    tokenizer = SimpleTokenizer()

    return model, image_preprocess, tokenizer

def inference(img_path, templates, classes, device="cuda:0"):
    model, image_process, tokenizer = load(
        "SLIP_VITB16", device)

    image = image_process(Image.open(img_path)).unsqueeze(0).to(device)

    text_features = []
    with torch.no_grad():
        # 言語特徴量
        for label in classes:
            texts = [t.format(label) for t in templates]
            texts = tokenizer(texts).to(device)
            texts = texts.view(-1, 77).contiguous()
            class_embeddings = model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0) # (n_classes, 512)

        # 画像特徴量
        image_features = model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # コサイン類似度→ロジット→確率
        logits_per_image = model.logit_scale.exp() * image_features @ text_features.t()
        pred_prob = torch.softmax(logits_per_image, dim=-1).cpu().numpy()
        return pred_prob

img_pathは画像のパス、templatesはプロンプト、classesはクラス名のテキストを表します。プロンプトとクラス名を固定で複数の画像を推論したい場合は、言語特徴量やモデルをキャッシュさせるとよいでしょう。

一番いい推論を頼む

SLIPを使って画像を推論してみましょう。使用するサンプルは「そんな装備で大丈夫か?」で話題になったエルシャダイ[22]です。フリー素材として公開されている[23]ので、ありがたく利用します(画像クレジットはこちらで付与したものです)。

ではイーノックが何を着ているのかSLIPで判定してみましょう。ちゃんと装備をつけていると判定できるでしょうか?

# プロンプト
templates = {
    "a picture of a {}.",
    "a {} in a video game.",
    "a {} in an animation.",
    "a {} in a movie."
}
classes = ["man wearing uniform", "man wearing swimsuits", "man wearing armor", "naked man"]
result = inference("e3_luciferpv_1080p_free/02/e3_luciferpv_1080068.jpg", templates, classes)
print(classes)
print(np.round(result, 3))

プロンプトは4個与えてアンサンブルしています。クラスの候補としては、

  • man wearing uniform(制服を着ている)
  • man wearing swimsuits(水着を着ている)
  • man wearing armor(鎧を着ている)
  • naked man(裸)

の4つを与えました。結果は以下の通りです。

['man wearing uniform', 'man wearing swimsuits', 'man wearing armor', 'naked man']
[[0.121 0.001 0.877 0.   ]]

87.7%の確率で鎧を着ているとなりました。「大丈夫か?」と心配されるような装備でも無事に鎧と認識できました。

Zero-shotの面白いところは、追加の訓練もモデルを変更しなくても、問題の枠組みを変えられるところです。先程は「何を着ているか?」の分類でしたが、今度は「どんな行動しているか?」の分類にしてみます。

イーノックがキックしているシーンです。この写真だけでキックと認識できるでしょうか? 選択肢を次のように与えて同様に推論してみます。

  • man shooting(射撃している)
  • man kicking(キックしている)
  • man sleeping(寝ている)
  • man eating(食べている)
['man shooting', 'man kicking', 'man sleeping', 'man eating']
[[0.16  0.818 0.018 0.004]]

81.8%で蹴りと判定できました。shootingは攻撃モーションから、近いと判断されたのかもしれません、行動認識と組み合わせると面白そうですね。

まとめ

本記事では、クロスモーダルでZero-shot LearningのブレイクスルーであるCLIPが、過去の論文からどういう思想のもとで生まれたのかを俯瞰し、Zero-shotのモデルとしてSLIPの使い方を紹介しました。ディープラーニングのゴッドファーザー達が過去に生み出したアイディアが脈々と受け継がれ、CLIPによって一気に開花し、まさに今Zero-shotかつクロスモーダルなモデルが、社会実装へと向かいつつあります。このダイナミズムをぜひ一緒に体感しましょう。

引用

  1. Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. "Siamese neural networks for one-shot image recognition." ICML deep learning workshop. Vol. 2. 2015.
  2. Palatucci, Mark, et al. "Zero-shot learning with semantic output codes." Advances in neural information processing systems 22 (2009).
  3. Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet classification with deep convolutional neural networks." Advances in neural information processing systems 25 (2012).
  4. Tom Mitchell. "Machine Learning." McGraw Hill. 1997.
  5. Larochelle, Hugo, Dumitru Erhan, and Yoshua Bengio. "Zero-data learning of new tasks." AAAI. Vol. 1. No. 2. 2008.
  6. Socher, Richard, et al. "Zero-shot learning through cross-modal transfer." Advances in neural information processing systems 26 (2013).
  7. Frome, Andrea, et al. "Devise: A deep visual-semantic embedding model." Advances in neural information processing systems 26 (2013).
  8. Li, Ang, et al. "Learning visual n-grams from web data." Proceedings of the IEEE International Conference on Computer Vision. 2017.
  9. Hadsell, Raia, Sumit Chopra, and Yann LeCun. "Dimensionality reduction by learning an invariant mapping." 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06). Vol. 2. IEEE, 2006.
  10. Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference on machine learning. PMLR, 2020.
  11. Chen, Shuo, et al. "Large-margin contrastive learning with distance polarization regularizer." International Conference on Machine Learning. PMLR, 2021.
  12. Radford, Alec, et al. "Learning transferable visual models from natural language supervision." International Conference on Machine Learning. PMLR, 2021.
  13. https://storage.googleapis.com/openimages/web/index.html
  14. https://www.flickr.com/people/tristanf/
  15. https://openai.com/blog/clip/
  16. Mu, Norman, et al. "SLIP: Self-supervision meets Language-Image Pre-training." arXiv preprint arXiv:2112.12750 (2021).
  17. https://github.com/facebookresearch/SLIP
  18. Li, Yangguang, et al. "Supervision exists everywhere: A data efficient contrastive language-image pre-training paradigm." arXiv preprint arXiv:2110.05208 (2021).
  19. Sun, Jianxin, et al. "AnyFace: Free-style Text-to-Face Synthesis and Manipulation." arXiv preprint arXiv:2203.15334 (2022).
  20. https://openai.com/dall-e-2/
  21. Ramesh, Aditya, et al. "Hierarchical Text-Conditional Image Generation with CLIP Latents." arXiv preprint arXiv:2204.06125 (2022).
  22. http://elshaddai.jp/elshaddai_crim/index.html
  23. http://elshaddai.jp/elshaddai_crim/freedeta.html