ExaWizards Engineer Blog

The engineering team blog of ExaWizards, Inc.

Real-time pose estimation in Android

This article is focused on Pose Estimation using TensorFlow Lite. I will guide you through every step from picking an ML model to displaying an output on the screen, with detailed explanations and materials for further reading. We will not dive deep into Machine Learning, however, as our primary goal is to learn how to use the tools provided by TensorFlow to accomplish the task of pose estimation. No prior Machine Learning experience is required, but it is assumed that you have some Java/Kotlin and Android proficiency. Without further ado, let’s get started!

Part 1: TensorFlow Lite

TensorFlow is an open source library for numerical computation and machine learning. It uses Python to provide an API for training and running ML models and deep neural networks, while executing operations in C++.

Data flow graphs are structures that describe how data moves through a series of processing nodes. Each node is a mathematical operation, and each node's input/output is a multidimensional data array, or a tensor.

Simply put, to receive an array of key points representing a human pose, we need to format the initial image to match the processing node's expected input and run it through the series of transformations described in a model - a process called inference.

TensorFlow Lite is a lightweight version of TensorFlow built specifically for mobile and embedded devices. It supports a set of core operations which have been tuned for performance while staying relatively lean in size. TFLite also provides an interpreter with hardware acceleration in Android (NNAPI). To learn more about TFLite and its constraints, please refer to this guide.

Quick start

To kick start your Android project, please check out the official documentation and this demo app:

Android guide

Pose Estimation demo

In short, to add the tflite module to your project, modify your app's build.gradle as follows:

// Check the latest tensorflow-lite version at JCenter:
// https://bintray.com/google/tensorflow/tensorflow-lite
ext.tfLiteVersion = '0.0.0-nightly'

android {
    defaultConfig {
        ndk {
            // include only relevant architectures to reduce apk size
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
}

dependencies {
    implementation "org.tensorflow:tensorflow-lite:$tfLiteVersion"
}

The dependency contains core TFLite classes. Let's go over some of them one by one:

Interpreter - A class that helps with building and accessing a native interpreter, which is an interface between Java code and the core C++ TensorFlow logic. In its constructor you can provide your pre-trained model (as a file or a MappedByteBuffer) and Interpreter.Options (more on that later).

Delegate - An interface for providing a native handle to a delegate - an executor that handles partial (or full) computation of a data flow graph.

Tensor - A representation of a multidimensional byte array containing input or output data.
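For example, here is a minimal sketch (assuming the model has already been loaded into a MappedByteBuffer called modelBuffer, as shown in Part 3) of inspecting a model's tensors through the Interpreter API:

import org.tensorflow.lite.Interpreter

// modelBuffer: MappedByteBuffer loaded elsewhere (hypothetical in this snippet)
val interpreter = Interpreter(modelBuffer)
val input = interpreter.getInputTensor(0)
val output = interpreter.getOutputTensor(0)
// e.g. prints "input: [1, 353, 257, 3] FLOAT32" for the PoseNet model discussed in Part 2
println("input: ${input.shape().contentToString()} ${input.dataType()}")
println("output: ${output.shape().contentToString()} ${output.dataType()}")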

Delegates

By default, all computation will be handled by the CPU. You can parallelize inference on CPU by setting the number of threads the task will run on:

val numThreads = 4 // depends on the number of cores the CPU has
val options = Interpreter.Options().apply { setNumThreads(numThreads) }

TensorFlow Lite provides 3 built-in delegates to run inference on:

GPU - provides a great increase in performance and power efficiency. I would recommend picking the GPU delegate as the default option, with the caveat that your device has to support OpenCL or OpenGL ES 3.1 and that not all operations are supported. You can read more about it in the official docs.

To add the GPU delegate to your project, add the following dependency:

// Check the latest gpu delegate version at JCenter
// https://bintray.com/google/tensorflow/tensorflow-lite-gpu
dependencies {
    implementation "org.tensorflow:tensorflow-lite-gpu:$tfLiteVersion"
}

NNAPI - a delegate that utilizes Neural Networks API providing hardware acceleration on newer Android devices (API 27+). It is included in the tensorflow-lite package, so you don't need to add an extra dependency.

Hexagon - a substitute for the NNAPI delegate on older Android devices that do not fully support the Neural Networks API.

Add the following dependency if you want to support older devices:

// Check the latest Hexagon version at JCenter
// https://bintray.com/google/tensorflow/tensorflow-lite-hexagon
ext.tfLiteHexagon = '0.0.0-nightly'
dependencies {
    implementation "org.tensorflow:tensorflow-lite-hexagon:$tfLiteHexagon"
}

Below is the complete interpreter setup snippet:

/* (c) ExaWizards */

sealed class DelegateOptions {
    data class CPU(val numThreads: Int): DelegateOptions()
    object GPU: DelegateOptions()
    object NNAPI: DelegateOptions()
    object Hexagon: DelegateOptions()
}

fun createInterpreter(
    context: Context, // needed for the Hexagon delegate
    model: MappedByteBuffer,
    delegateOptions: DelegateOptions
): Interpreter {
    val options = Interpreter.Options().apply {
        when (delegateOptions) {
            is DelegateOptions.CPU -> setNumThreads(delegateOptions.numThreads)
            DelegateOptions.NNAPI -> setUseNNAPI(true)
            DelegateOptions.GPU -> addDelegate(GpuDelegate())
            DelegateOptions.Hexagon -> addDelegate(HexagonDelegate(context))
        }
    }
    return Interpreter(model, options)
}

Support library (experimental)

The TensorFlow team provides an optional package with various utility classes to simplify image operations and tensor buffer processing. If you don't want to deal with bitmap manipulations and bit shifting, then give this library a shot!

Currently it's in beta, so please be careful when adding it to your main application. I'd recommend playing around with it in a side project to catch any potential shortcomings for your use case.

To add the dependency, modify your build.gradle:

// Check the latest support library version at JCenter
// https://bintray.com/google/tensorflow/tensorflow-lite-support
ext.tfLiteSupportVersion = '0.1.0-rc1'
dependencies {
    implementation "org.tensorflow:tensorflow-lite-support:$tfLiteSupportVersion"
}

Let's take a look at some of the classes and interfaces available:

ImageProcessor - a class that accumulates various transformations and applies them to a target TensorImage

ImageOperator - a base interface for TensorImage transformations, including:

  • Rot90Op - rotate an image by 90 degrees counter-clockwise N times.
  • ResizeOp - resize an image to match the target size. It performs scaling, so be careful to preserve your original aspect ratio.
  • ResizeWithCropOrPadOp - crop or pad an image to match your model's expected input size. It does not scale the original image, so make sure to scale it down before applying this operator.

TensorOperator - a base interface for TensorBuffer transformations:

  • NormalizeOp - perform normalization - adjust buffer values to a common scale, usually in a range of [-1; 1].
  • QuantizeOp - perform quantization - map float values to a smaller set of integer numbers. It is used in quantized models to increase performance at the cost of precision.
  • DequantizeOp - reverse quantization.
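These operators can also be applied outside of an image pipeline via TensorProcessor, the TensorBuffer counterpart of ImageProcessor. Below is a small sketch assuming the 0.1.0 support library API; the zero point and scale passed to DequantizeOp are purely illustrative:

import org.tensorflow.lite.support.common.TensorProcessor
import org.tensorflow.lite.support.common.ops.DequantizeOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer

// quantizedOutput: a TensorBuffer filled by the interpreter (hypothetical)
val tensorProcessor = TensorProcessor.Builder()
    .add(DequantizeOp(0f, 1 / 255f)) // zeroPoint = 0, scale = 1/255 - illustrative values only
    .build()
val dequantized: TensorBuffer = tensorProcessor.process(quantizedOutput)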

Below is an example of building an ImageProcessor and transforming an image bitmap:

/* (c) ExaWizards */

val imageProcessor = ImageProcessor.Builder()
    .add(ResizeOp(scaledHeight, scaledWidth, ResizeOp.ResizeMethod.BILINEAR))
    .add(Rot90Op(numRotations))
    .add(ResizeWithCropOrPadOp(modelHeight, modelWidth))
    // f(x) = (x - 127.5) / 127.5; f(x) ∈ [-1; 1]; x ∈ [0; 255]
    .add(NormalizeOp(127.5f, 127.5f))
    .build()
val tensorImage = TensorImage.fromBitmap(bitmap)
val processedImage = imageProcessor.process(tensorImage)

TFLite wrapper (experimental)

If your model contains metadata, you can use the TensorFlow Lite wrapper code generator. The generated Model wrapper eliminates the need to set up delegates, manually perform image transformations, or deal with raw TensorBuffer output. How helpful the generated code will be depends entirely on the completeness of the metadata. Also, keep in mind that this feature is in an experimental phase, so you'll probably have to wait until it becomes stable before replacing all your ML-related logic with generated code.

To learn more about the wrapper code generator, please refer to the official docs.

Part 2: Model

To accomplish our task - human pose estimation - it is crucial that we have a basic understanding of our ML model and learn about our expected inputs/outputs. The TFLite "Getting started" page and linked source code provide enough information to kick start a new Proof of Concept project, but if we are going to make any changes to the core logic or simply want to compare existing options - it's better to know what we're dealing with.

Picking the right model

Let's start by examining a repository of open-sourced ML models - TensorFlow hub. This is a great place to search for domain-specific, format-specific solutions.

Our search query would be the "image pose detection" domain with the "model format" filter set to TFLite (as of June 2020, there's only one model satisfying these criteria - MobileNet_075). Now, we have two options: the pure model or the model + metadata. From tensorflow.org:

TensorFlow Lite metadata provides a standard for model descriptions. The metadata is an important source of knowledge about what the model does and its input / output information. The metadata consists of both - human readable parts which convey the best practice when using the model, and - machine readable parts that can be leveraged by code generators, such as the TensorFlow Lite Android code generator.

Let's examine the model with metadata further. I found a useful tool for visualizing the model's structure: Netron. After uploading the .tflite file, we can see the convolutional layers the model has and check what the expected inputs and outputs are:

[Figure: Model graph visualized in Netron]

[Figure: Model metadata]

A closer look

First, let's understand the input requirements:

  • Image: FloatArray [1][353][257][3]

To prepare the image for inference, we'll need to scale it down to 257x353 (width x height) pixels, extract each pixel's RGB values and normalize them, meaning the values should fall within [-1; 1].
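In code, that normalization step boils down to something like this (a one-line sketch; the 127.5 values match the NormalizeOp parameters used later in this article):

// Map an 8-bit channel value in [0; 255] to a float in [-1; 1].
fun normalizeChannel(value: Int): Float = (value - 127.5f) / 127.5f
// normalizeChannel(0) == -1.0f, normalizeChannel(255) == 1.0f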

Second, let's pay attention to the outputs:

  • Image (Grayscale): [1][23][17][17]

    A heatmap: the input image reduced to a 23x17 grid, where each of the 17 keypoints receives a "confidence score" in every grid cell.

  • Offsets: [1][23][17][34]

    Since the output grid is much smaller than the original image, we only get a rough idea of where each keypoint is. Offset vectors are here to help — once we pick the heatmap cell (keyPoint.x, keyPoint.y) with the highest score, apply the following formulas to get the final coordinates:

    y = keyPoint.y / (heatmapHeight - 1) * originalHeight + offsets[0][keyPoint.y][keyPoint.x][keyPoint.index]

    x = keyPoint.x / (heatmapWidth - 1) * originalWidth + offsets[0][keyPoint.y][keyPoint.x][keyPoint.index + 17]

  • Forward displacement: [1][23][17][64]

    Backward displacement: [1][23][17][1]

    In multi-pose estimation, when there are multiple poses to detect, it is not enough to pick a keypoint with the highest score — we need to pick multiple keypoints and group them into a graph representing a distinct human pose. Displacement arrays are used in a fast greedy decoding algorithm explained in this paper: PersonLab: Person Pose Estimation and Instance Segmentation with a Bottom-Up, Part-Based, Geometric Embedding Model. I will discuss the implementation later in the series.

The output structure seems to correspond with TensorFlow Pose Estimation starter guide:

  • Heatmaps: [1][height][width][N]
  • Offsets: [1][height][width][N * 2]
  • Forward displacements: [1][height][width][E * 2]
  • Backward displacements: [1][height][width][E * 2]

You might have noticed that something doesn't add up. The backward displacements matrix should be the same shape as the forward displacements: [1][23][17][64], but instead we are getting [1][23][17][1]. I believe it's a known problem (it is mentioned on StackOverflow), however it only affects multi-pose estimation. For single-pose estimation we will be using a much simpler "brute-force" solution that doesn't involve part-based graph traversal.

Part 3: Inference

Now that I've given an overview of TFLite, models and the available support tools, it's time to dive into the process of inference. The goal is to feed a prepared TensorImage to an interpreter and extract 17 key points, each with its (x, y) location and probability (confidence).

Preparation

If you manually downloaded the right model for your task, I recommend placing it in the /assets folder. If you don't want to check the file into VCS, simply add it to .gitignore and use this handy Gradle script, which will download the file automatically at build time:

/* (c) ExaWizards */

// download.gradle
def targetFile = "src/main/assets/posenet_model_meta.tflite"
def modelFloatDownloadUrl = "https://tfhub.dev/tensorflow/lite-model/posenet/mobilenet/float/075/1/metadata/1?lite-format=tflite"

task downloadModelFloat(type: DownloadUrlTask) {
    doFirst {
        println "Downloading ${modelFloatDownloadUrl}"
    }
    sourceUrl = "${modelFloatDownloadUrl}"
    target = file("${targetFile}")
}

class DownloadUrlTask extends DefaultTask {
    @Input
    String sourceUrl

    @OutputFile
    File target

    @TaskAction
    void download() {
        ant.get(src: sourceUrl, dest: target)
    }
}

preBuild.dependsOn downloadModelFloat

// Add this line to your build.gradle
// apply from:'download.gradle'

Now, we need to set up an interpreter. Depending on your target devices and benchmarks, you may choose one of the few available delegates, and load the model from the assets folder:

/* (c) ExaWizards */

private fun createInterpreter(device: Model.Device): Interpreter {
        val options = Interpreter.Options().apply {
            when (device) {
                Model.Device.CPU -> setNumThreads(numThreads)
                Model.Device.GPU -> addDelegate(GpuDelegate())
                Model.Device.NNAPI -> setUseNNAPI(true)
            }
        }
        return Interpreter(loadModelFile("posenet_model_meta.tflite", context), options)
}

private fun loadModelFile(path: String, context: Context): MappedByteBuffer {
        val fileDescriptor = context.assets.openFd(path)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        return inputStream.channel.map(
            FileChannel.MapMode.READ_ONLY, fileDescriptor.startOffset, fileDescriptor.declaredLength
        )
}

To avoid getting errors during model loading, add this to your app's build.gradle to disable .tflite file compression:

android {
    ...
    aaptOptions {
        noCompress "tflite"
    }
}

Assuming you've already prepared TensorImage (check Part 4 for more info), let's proceed with inference.

Single pose estimation

The interpreter takes an input array of ByteBuffer with a Tensor shape defined by the model; in our case it's [1, 353, 257, 3]. The output array will contain four 4-dimensional float arrays: Heatmaps, Offsets, Forward displacements, Backward displacements. You can get their default shapes by calling

getInterpreter().getOutputTensor(i).shape(),

where i ∈ [0, 3], as we have 4 output tensors.

/* (c) ExaWizards */

val outputMap = mutableMapOf<Int, Any>()

fun estimatePose(tensorImage: TensorImage): Person {
    val inputArray = arrayOf(tensorImage.buffer)
    (0 until interpreter.outputTensorCount).forEach {
        outputMap[it] = reshapeTo4dArray(interpreter.getOutputTensor(it).shape())
    }
    interpreter.runForMultipleInputsOutputs(inputArray, outputMap)
    // parse outputMap
    return extractKeyPoints(outputMap, tensorImage.width, tensorImage.height)
}

private fun reshapeTo4dArray(shape: IntArray): Array<Array<Array<FloatArray>>> =
    Array(shape[0]) { Array(shape[1]) { Array(shape[2]) { FloatArray(shape[3]) } } }

Next step is to extract key points and create a Person object that contains all of the information we need to draw a person's shape on-screen. Since we are focusing on single pose estimation for now, we will only need two arrays: Heatmaps and Offsets. The idea is to find the locations of the key points with the highest confidence scores, calculate their (x, y) coordinates using offset adjustment and normalize the confidence score to the range [0;1].

/* (c) ExaWizards */

// order is important!
enum class BodyPart {
    NOSE, LEFT_EYE, RIGHT_EYE, LEFT_EAR, RIGHT_EAR, LEFT_SHOULDER, RIGHT_SHOULDER,
    LEFT_ELBOW, RIGHT_ELBOW, LEFT_WRIST, RIGHT_WRIST, LEFT_HIP, RIGHT_HIP,
    LEFT_KNEE, RIGHT_KNEE, LEFT_ANKLE, RIGHT_ANKLE
}

data class Position(val x: Int, val y: Int)
data class KeyPoint(val bodyPart: BodyPart, val position: Position, val score: Float)
data class Person(val keyPoints: List<KeyPoint>, val score: Float)

@Suppress("UNCHECKED_CAST")
private fun extractKeyPoints(
    outputMap: Map<Int, Any>,
    imageWidth: Int,
    imageHeight: Int
): Person {
    val heatMaps = outputMap[0] as Array<Array<Array<FloatArray>>>
    val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>

    val height = heatMaps[0].size
    val width = heatMaps[0][0].size
    val numKeyPoints = heatMaps[0][0][0].size

    val keyPoints = mutableListOf<KeyPoint>()
    val bodyParts = enumValues<BodyPart>()
    var totalConfidence = 0f
    for (keyPoint in 0 until numKeyPoints) {
        var maxVal = heatMaps[0][0][0][keyPoint]
        var maxRow = 0
        var maxCol = 0
        // Find the (row, col) locations of where the keyPoints are most likely to be.
        for (row in 0 until height) {
            for (col in 0 until width) {
                if (heatMaps[0][row][col][keyPoint] > maxVal) {
                    maxVal = heatMaps[0][row][col][keyPoint]
                    maxRow = row
                    maxCol = col
                }
            }
        }
        val yDisplacement = offsets[0][maxRow][maxCol][keyPoint]
        val xDisplacement = offsets[0][maxRow][maxCol][keyPoint + numKeyPoints]
        val yCoord = maxRow / (height - 1).toFloat() * imageHeight + yDisplacement
        val xCoord = maxCol / (width - 1).toFloat() * imageWidth + xDisplacement
        val confidence = sigmoid(maxVal)
        val bodyPart = bodyParts[keyPoint]
        totalConfidence += confidence
        keyPoints.add(KeyPoint(bodyPart, Position(xCoord.toInt(), yCoord.toInt()), confidence))
    }

    return Person(keyPoints, totalConfidence / numKeyPoints)
}

/** Returns a value within [0,1].   */
private fun sigmoid(x: Float): Float {
    return (1.0f / (1.0f + exp(-x)))
}

And there we have it - a Person object containing key point locations and their confidence scores! The next step would be to filter key points by a confidence threshold and translate the coordinates back to the starting image dimensions - remember, we applied a number of transformations (rotation, scale, crop) to the original input. I will discuss this logic later in the series, using a CameraX feed as an example.
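As a quick preview of the filtering step, here is a minimal sketch (the 0.5f threshold is an arbitrary assumption; tune it for your use case):

// Hypothetical helper: keep only the keypoints the model is reasonably confident about.
fun Person.confidentKeyPoints(threshold: Float = 0.5f): List<KeyPoint> =
    keyPoints.filter { it.score >= threshold }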

Multi-pose estimation

If we want to get more than one person's key points, the brute-force key point search solution above will not work. As I mentioned before, we have to use forward and backward displacement arrays to handle this task.

The idea of a modified algorithm is described in this PersonLab paper:

[Figure: Multi-pose decoding algorithm, from the PersonLab paper]

As you can see, the algorithm is non-trivial and requires a bit of time to get right. You can try implementing it yourself, or use one of these open source projects as an example:

PoseNet Typescript by TensorFlow

PoseNet Java by shaqian

Important note: Before you decide to enable multi-pose estimation, make sure your model supports it! The current model listed on TensorFlow Hub returns incorrect displacement arrays, so try using a modified version from this StackOverflow answer instead.

[Figure: Multi-pose estimation output]

Part 4: Camera 1̶ 2̶ X

Android CameraX is a great library for seamlessly integrating camera logic into your project's codebase by combining existing use cases that interface with the device's camera API: Preview, Image Analysis and Image Capture. If you're not familiar with the CameraX architecture, please refer to the official documentation page.

[Figure: CameraX use cases, from the CameraX documentation]

In this part we will focus on combining Preview with Image Analysis to display an inferred human pose on screen in real time.

Preparation

To get started with CameraX and get a better idea of its architecture and capabilities, I recommend following Google’s codelab page. If you want a quick start by looking at a complete implementation, you can refer to my PoseNet sample (coming soon).

Image Analysis

Once you're familiar with the CameraX API, let's start by setting up an ImageAnalysis use case. First, we might want to request a specific resolution by calling

val builder = ImageAnalysis.Builder().setTargetResolution(Size(width, height))

Keep in mind that in order to infer a human pose in real time we will need to heavily downscale our original image to match the Model's input size. However, we can't just request any arbitrary resolution; instead, it will depend on the Camera implementation and will fall back to the nearest available resolution in case the requested size doesn't exist.

Next, let's set an appropriate backpressure strategy. Inference takes time, so we won't be able to process every frame from the camera feed before the next one arrives. To avoid buffer overflow, we will skip subsequent frames until we're done processing the current one:

builder.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)

Finally, let's create an Analyzer. The ImageAnalysis.Analyzer will receive an ImageProxy object, which we will use to get an Image and transform it using the ImageProcessor class provided by the TensorFlow Lite support library.

Here's sample code for the ImageAnalysis setup:

/* (c) ExaWizards */

val useCase: ImageAnalysis = ImageAnalysis.Builder()
    .setTargetResolution(Size(targetWidth, targetHeight))
    .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
    .build()
    .apply(::setAnalyzer)

private fun setAnalyzer(imageAnalysis: ImageAnalysis) {
    imageAnalysis.setAnalyzer(
        Executors.newSingleThreadExecutor(),
        ImageAnalysis.Analyzer { image ->
            val transformedImage = image.use {
                processImage(
                    image.image ?: throw Exception("Unexpected ImageProxy"),
                    image.imageInfo.rotationDegrees,
                    modelConfig.modelWidth,
                    modelConfig.modelHeight
                )
            }
            val person = estimatePose(transformedImage.tensorImage)
            onPoseData(PoseData(
                person,
                transformedImage.originalSize,
                transformedImage.scaledSize,
                transformedImage.paddedSize,
                transformedImage.orientation)
            )
        }
    )
}

data class PoseData(
    val person: Person,
    val originalSize: Size,
    val scaledSize: Size,
    val paddedSize: Size,
    val orientation: Orientation,
    val transformedBitmap: Bitmap? = null // not populated in the analyzer snippet above
)

Important note: if you're using the GPU delegate for inference, remember that only the original thread that instantiated a GPU delegate can call it. Here, I'm using Executors.newSingleThreadExecutor() as an image processing executor and lazily creating a GPU instance. That means I cannot reuse the same delegate once I discard the ImageAnalysis object and have to instantiate a new delegate again.
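Here is a minimal sketch of that idea, assuming a single-threaded executor owns the interpreter for its whole lifetime (the class and helper names are hypothetical):

import java.nio.MappedByteBuffer
import java.util.concurrent.Executors
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.GpuDelegate

class GpuInference(private val model: MappedByteBuffer) {
    // Pass this same executor to ImageAnalysis.setAnalyzer so that creation,
    // inference and disposal of the GPU delegate all happen on one thread.
    val executor = Executors.newSingleThreadExecutor()
    private var delegate: GpuDelegate? = null
    private var interpreter: Interpreter? = null

    // Called from the analyzer, i.e. on the executor thread.
    fun getOrCreateInterpreter(): Interpreter = interpreter ?: run {
        val gpu = GpuDelegate().also { delegate = it }
        Interpreter(model, Interpreter.Options().addDelegate(gpu)).also { interpreter = it }
    }

    fun shutdown() {
        // Release native resources on the thread that created them.
        executor.execute {
            interpreter?.close()
            delegate?.close()
            interpreter = null
            delegate = null
        }
        executor.shutdown()
    }
}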

Image Transformation

To prepare an image for inference we need to perform the following series of transformations:

Downscaling → Rotation → Cropping → Normalization

In order to translate the resulting pose coordinates back to the original dimensions, I recommend keeping each step's variables in a data class — that way it will be easier to apply each transformation in reverse order.

Important note: CameraX provides an Image in YUV_420_888 format, which we will convert to RGB values in order to extract a byte buffer for further image processing with PoseNet. I am using RenderScript for YUV → RGB conversion; you can take a look at the "sample approach" here.
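For reference, here is a rough sketch of the RenderScript intrinsic in isolation, assuming the frame has already been repacked into an NV21 byte array (the YuvToRgbConverter used in the snippets below takes an Image directly and handles the plane/stride extraction):

import android.content.Context
import android.graphics.Bitmap
import android.renderscript.Allocation
import android.renderscript.Element
import android.renderscript.RenderScript
import android.renderscript.ScriptIntrinsicYuvToRGB
import android.renderscript.Type

class Nv21ToRgbConverter(context: Context) {
    private val rs = RenderScript.create(context)
    private val script = ScriptIntrinsicYuvToRGB.create(rs, Element.U8_4(rs))

    fun convert(nv21: ByteArray, output: Bitmap) {
        // Wrap the NV21 bytes in an allocation the intrinsic can read from.
        val yuvType = Type.Builder(rs, Element.U8(rs)).setX(nv21.size).create()
        val inAllocation = Allocation.createTyped(rs, yuvType, Allocation.USAGE_SCRIPT)
        val outAllocation = Allocation.createFromBitmap(rs, output)
        inAllocation.copyFrom(nv21)
        script.setInput(inAllocation)
        script.forEach(outAllocation) // writes RGBA pixels into the output allocation
        outAllocation.copyTo(output)
    }
}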

The TensorFlow Lite support library provides the helper operations discussed earlier, each of which creates a new TensorImage holding a modified Bitmap. A complete image processing function looks something like this:

/* (c) ExaWizards */

private val yuvToRgbConverter = YuvToRgbConverter(context.applicationContext)

data class TransformedImage(
    val tensorImage: TensorImage,
    val originalSize: Size,
    val scaledSize: Size,
    val paddedSize: Size,
    val orientation: Orientation
)

fun processImage(
    image: Image,
    rotationDegrees: Int,
    targetWidth: Int, // input tensor size
    targetHeight: Int // input tensor size
): TransformedImage {
    val imageBitmap = Bitmap.createBitmap(image.width, image.height, Bitmap.Config.ARGB_8888)
    yuvToRgbConverter.yuvToRgb(image, imageBitmap)
    val numRotations = rotationDegrees / 90
    val scale = min(image.height.toDouble() / targetWidth, image.width.toDouble() / targetHeight)
    val scaledSize = Size((image.width / scale).toInt(), (image.height / scale).toInt())
    val orientation = if (numRotations % 2 == 0) {
        Orientation.HORIZONTAL
    } else {
        Orientation.VERTICAL
    }
    val imageProcessor = ImageProcessor.Builder()
        .add(ResizeOp(scaledSize.height, scaledSize.width, ResizeOp.ResizeMethod.BILINEAR))
        .add(Rot90Op(-numRotations))
        .add(ResizeWithCropOrPadOp(targetHeight, targetWidth))
        .add(NormalizeOp(127.5f, 127.5f))
        .build()
    val tensorImage = TensorImage.fromBitmap(imageBitmap)
    return TransformedImage(
        imageProcessor.process(tensorImage),
        Size(image.width, image.height),
        scaledSize,
        Size(targetWidth, targetHeight),
        orientation
    )
}

Coordinate translation

The final step is to extract the inferred pose's key points and apply a coordinate translation algorithm to match the camera's preview layout. The tricky part is adding (x, y) padding in case your pose overlay view's aspect ratio doesn't match the original image's. The CameraX preview window does the same, and the effect is similar to ImageView's centerCrop scale type. Let's add this extension function:

/* (c) ExaWizards */

private val minConfidence = 0.7f

fun PoseData.extractKeyPoints(width: Int, height: Int): Map<BodyPart, PointF> {
    val scaledWidth: Int
    val scaledHeight: Int
    val originalWidth: Int
    val originalHeight: Int
    when (orientation) {
        Orientation.HORIZONTAL -> {
            scaledWidth = scaledSize.width
            scaledHeight = scaledSize.height
            originalWidth = originalSize.width
            originalHeight = originalSize.height
        }
        Orientation.VERTICAL -> {
            scaledWidth = scaledSize.height
            scaledHeight = scaledSize.width
            originalWidth = originalSize.height
            originalHeight = originalSize.width
        }
    }
    val xOffset = (scaledWidth - paddedSize.width) / 2.0
    val yOffset = (scaledHeight - paddedSize.height) / 2.0

    // crop or pad to fit current view
    val originalRatio = originalHeight / originalWidth.toDouble()
    val widthFactor: Double
    val heightFactor: Double
    val xPad: Double
    val yPad: Double
    if (width * originalRatio >= height) {
        // width is the basis
        xPad = .0
        yPad = (height - width * originalRatio) / 2
        widthFactor =
            (width / originalWidth.toDouble()) * originalWidth / scaledWidth.toDouble()
        heightFactor =
            (width * originalRatio / originalHeight.toDouble()) * originalHeight / scaledHeight.toDouble()
    } else {
        xPad = (width - height / originalRatio) / 2
        yPad = .0
        widthFactor =
            ((height / originalRatio) / originalWidth.toDouble()) * originalWidth / scaledWidth.toDouble()
        heightFactor =
            (height / originalHeight.toDouble()) * originalHeight / scaledHeight.toDouble()
    }

    return person.keyPoints
            .asSequence()
            .filter { it.score > minConfidence }
            .map {
                it.bodyPart to it.position.toAdjustedPoints(
                    widthFactor,
                    heightFactor,
                    xOffset,
                    yOffset,
                    xPad,
                    yPad
                )
            }
            .toMap()
}

private fun Position.toAdjustedPoints(
    widthFactor: Double,
    heightFactor: Double,
    xOffset: Double,
    yOffset: Double,
    xPad: Double,
    yPad: Double
) = PointF(
    ((x + xOffset) * widthFactor + xPad).toFloat(),
    ((y + yOffset) * heightFactor + yPad).toFloat()
)

That's it! Now all you need to do is to invalidate() the view on every update from ImageAnalyzer and draw a circle where each of the extracted key points are:

/* (c) ExaWizards */

// inside PoseOverlayView.kt

private var pointMap: Map<BodyPart, PointF> = emptyMap()
    set(value) {
        field = value
        invalidate()
    }

private val circleRadius = 8.0f
private val circlePaint: Paint = Paint().apply {
    color = Color.WHITE
    strokeWidth = 8.0f
}

fun updatePoseData(poseData: PoseData) {
    // width and height here are the overlay View's own dimensions
    pointMap = poseData.extractKeyPoints(width, height)
}

override fun onDraw(canvas: Canvas?) {
    super.onDraw(canvas)
    canvas ?: return
    canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR)
    pointMap.values.forEach { canvas.drawCircle(it.x, it.y, circleRadius, circlePaint) }
}

Part 5: Definition of Done

Previously we learned how to set up an interpreter, pick the right model, attach a CameraX analyzer and draw the output on a canvas. There’s one more thing left to cover: how to improve the user experience depending on your use case. You may need more precise key point estimation, or maybe fast inference time is critical for a smooth UX. In this part we will discuss some tips and tricks that may be worth considering.

Optimizing for accuracy

Posenet is a fully convolutional model, meaning it was trained with a specific image size but can process larger images, sacrificing performance in favor of accuracy. The only rule is that each input dimension should be a multiple of 16, plus 1 (see this answer). Previously, we talked about the expected input/output tensor shapes: [1, 353, 257, 3] for the input and [1, 23, 17, X] for the various output tensors. As you may remember, the input shape represents the number of input image pixels times 3 (one Float per RGB channel). The output shape scales linearly with the outputStride: outWidth = ((inputWidth - 1) / outputStride) + 1, where the outputStride can be 8, 16 or 32. The lower the outputStride, the higher the accuracy, but the slower the speed.
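As a quick sanity check of that formula (assuming this model's default outputStride of 16):

// Heatmap resolution for a given input dimension and output stride.
fun outputSize(inputSize: Int, outputStride: Int): Int = (inputSize - 1) / outputStride + 1

// outputSize(353, 16) == 23, outputSize(257, 16) == 17  -> the [1, 23, 17, X] shapes above
// outputSize(705, 16) == 45, outputSize(513, 16) == 33  -> the resized shapes used below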

A pre-trained .tflite model does not support a variable output stride, but we can change the input tensor shape and adjust our expectations for the output tensor. Here’s how to do it:

/* (c) ExaWizards */

// create an interpreter first
val interpreter: Interpreter = Interpreter(model, options)

// let's double the size of the default input tensor
fun resizeInput() {
    interpreter.resizeInput(0, intArrayOf(1, 705, 513, 3))
}

// remember to scale the processed image to 705x513 instead of 353x257
fun <T> estimatePose(byteBuffer: ByteBuffer, decoder: Decoder<T>): T {
    val inputArray = arrayOf(byteBuffer)
    // output shapes will become [1, 45, 33, X]
    interpreter.runForMultipleInputsOutputs(inputArray, outputs.buffer)
    return decoder.decode(create4DArray(outputs))
}

Important note: remember that inference time does not scale linearly. On my Pixel 1 test device, using the GPU delegate, I was able to get ~70ms average inference, while doubling the input size brought the time up to ~270ms!

This method is useful if you don’t care about real-time performance and instead are analyzing a static image while running some scene transition animation or showing a brief loading screen after taking a picture.

Optimizing for performance

If we can afford to sacrifice accuracy to gain true real-time pose estimation even on lower-end devices, it might be a good idea to scale the image down to an even smaller size. Remember to adjust your input/output tensor shapes accordingly.
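For example, a minimal sketch of shrinking the input tensor (the 161x225 size is an arbitrary assumption that still follows the "multiple of 16, plus 1" rule; benchmark accuracy on your own data):

// 225 = 16 * 14 + 1, 161 = 16 * 10 + 1
interpreter.resizeInput(0, intArrayOf(1, 225, 161, 3))
// with an outputStride of 16, the heatmap/offset shapes become [1, 15, 11, X]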

One other bit of advice I can give you is to optimize the image processing part. During my tests on Pixel 1 I was using the TFLite support library, and the image processing took up to ~60ms on average, almost the same time as inference itself! Here's what it looked like:

/* (c) ExaWizards */

val imageProcessor = ImageProcessor.Builder()
    .add(ResizeOp(scaledSize.height, scaledSize.width, ResizeOp.ResizeMethod.BILINEAR))
    .add(Rot90Op(-numRotations))
    .add(ResizeWithCropOrPadOp(targetHeight, targetWidth))
    .add(NormalizeOp(127.5f, 127.5f))
    .build()
val tensorImage = TensorImage.fromBitmap(imageBitmap)
val tensorBuffer = imageProcessor.process(tensorImage).tensorBuffer

Under the hood each ImageOperator produces a new Bitmap by applying a transformation to the original image, and the last operation in the chain transforms a Bitmap into a ByteBuffer and performs normalization on it. Let's take a look at how we can optimize this:

  • Combine ResizeOp with Rot90Op
  • Leave ResizeWithCropOrPadOp as is
  • Combine Bitmap → ByteBuffer with NormalizeOp

You can create your own operators by implementing the ImageOperator and TensorOperator interfaces, which are part of the TFLite support library, but to better illustrate how it works, here is a sample image transformation without ImageProcessor:

/* (c) ExaWizards */

val rotateMatrix = Matrix()
val scale = min(
    image.height.toDouble() / targetWidth,
    image.width.toDouble() / targetHeight
)
val scaledSize = Size((image.width / scale).toInt(), (image.height / scale).toInt())
val sx: Float = scaledSize.width / image.width.toFloat()
val sy: Float = scaledSize.height / image.height.toFloat()
// combine ResizeOp with Rot90Op
rotateMatrix.preScale(sx, sy)
rotateMatrix.postRotate(rotationDegrees.toFloat())
val rotatedBitmap = Bitmap.createBitmap(
    imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height,
    rotateMatrix, true
)

// see ResizeWithCropOrPadOp.java for implementation
val croppedBitmap = cropBitmap(rotatedBitmap, targetHeight, targetWidth)

// extract RGB values and normalize them
val mean = 128f
val std = 128f
val bytesPerChannel = 4
val inputChannels = 3
val batchSize = 1
val inputBuffer = ByteBuffer.allocateDirect(
    batchSize * bytesPerChannel * croppedBitmap.height * croppedBitmap.width * inputChannels
)
inputBuffer.order(ByteOrder.nativeOrder())
inputBuffer.rewind()
val intValues = IntArray(croppedBitmap.width * croppedBitmap.height)
croppedBitmap.getPixels(intValues, 0, croppedBitmap.width, 0, 0, croppedBitmap.width, croppedBitmap.height)
for (pixelValue in intValues) {
    inputBuffer.putFloat(((pixelValue shr 16 and 0xFF) - mean) / std)
    inputBuffer.putFloat(((pixelValue shr 8 and 0xFF) - mean) / std)
    inputBuffer.putFloat(((pixelValue and 0xFF) - mean) / std)
}
return inputBuffer

By applying this simple improvement I was able to save ~25ms on average, bringing the image processing time down to ~35ms.

Frame interpolation

My final tip for you is about providing users with a smooth UX even if your computational budget is relatively low.

Like I mentioned before, the Pixel 1 is not the most performant device to run inference on, with an average time of ~100ms (including image processing) using default tensor shapes. That means every pose update will take at least 100ms to appear on screen, resulting in an average of 10 frames per second. What should we do if we simply can't go faster, but still want smooth 60fps updates?

In that case I suggest a trick involving interpolation. The idea is that whenever a new pose update comes in, instead of drawing the new frame immediately, we gradually move the existing points to their new destinations over time, creating the illusion of smooth updates. If an update arrives before the points reach their previous destination, simply start a new interpolation from their current positions to the new ones. Keep in mind that this trick introduces an artificial delay and de-syncs the camera feed from the pose overlay view, making the experience arguably worse on more performant devices (i.e., those capable of at least 30fps updates). Still, you can make the interpolation time dynamic and adjust it at runtime based on how long the last inference took to complete.

/* (c) ExaWizards */

// in FluidPoseView.kt
...
private var pointMap: MutableMap<BodyPart, PointF> = mutableMapOf()
private val interpolator = LinearInterpolator()
private val flow = MutableStateFlow<MutableMap<BodyPart, PointF>?>(null)
private val coroutineScope: CoroutineScope? = (context as? AppCompatActivity)?.lifecycleScope
private var animJob: Job? = null
private val durationNanos = 1e8f

private val evaluator = object : TypeEvaluator<MutableMap<BodyPart, PointF>> {
    private val pointFEvaluator: PointFEvaluator = PointFEvaluator()

    override fun evaluate(
        fraction: Float,
        startValue: MutableMap<BodyPart, PointF>?,
        endValue: MutableMap<BodyPart, PointF>?
    ): MutableMap<BodyPart, PointF> {
        val updated = startValue?.mapValues { entry ->
            val startPointF = entry.value
            val endPointF = endValue?.get(entry.key)
            when {
                startPointF == zeroPoint -> endPointF ?: zeroPoint
                endPointF == null -> zeroPoint
                else -> pointFEvaluator.evaluate(fraction, startPointF, endPointF)
            }
        }?.toMutableMap() ?: mutableMapOf()
        endValue?.forEach {
            updated.getOrPut(it.key) { it.value }
        }
        return updated
    }
}

override fun onAttachedToWindow() {
    super.onAttachedToWindow()
    animJob = coroutineScope?.launch {
        flow.collectLatest { endValue ->
            endValue ?: return@collectLatest
            val startValue = pointMap
            val startTime = System.nanoTime()
            while (true) {
                val time = awaitFrame()
                val fraction = (time - startTime) / durationNanos
                if (fraction >= 1.0f) {
                    break
                }
                val interpolatedFraction = interpolator.getInterpolation(fraction)
                pointMap = evaluator.evaluate(interpolatedFraction, startValue, endValue)
                invalidate()
            }
        }
    }
}

override fun onDetachedFromWindow() {
    super.onDetachedFromWindow()
    animJob?.cancel()
}

override fun onDraw(canvas: Canvas?) {
    super.onDraw(canvas)
    canvas ?: return
    pointMap.values
        .filter { it != zeroPoint }
        .forEach { canvas.drawCircle(it.x, it.y, circleRadius, circlePaint) }
}

[Animation: Low performance, discrete frames]

[Animation: Low performance, interpolated frames]

Conclusion

We learned how to integrate TensorFlow Lite into a project, explored the TensorFlow Lite support library, analyzed the Posenet model, discussed what inference is, and leveraged the CameraX API to efficiently analyze the camera feed in real time. You can apply many of the concepts discussed here to other use cases too, and to give you a quick start I will prepare an open source sample project showcasing the on-device machine learning kit.

Thanks for your time!