
02 October 2024

PyTorch machine learning models on Android


Posted by Paul Ruiz – Senior Developer Relations Engineer

Earlier this year we launched Google AI Edge, a suite of tools that provides easy access to ready-to-use ML tasks, frameworks for building ML pipelines, and support for running popular LLMs and custom models – all on-device. For AI on Android Spotlight Week, the Google team is highlighting various ways that Android developers can use machine learning to improve their applications.

In this post, we'll dive into Google AI Edge Torch, which enables you to convert PyTorch models to run locally on Android and other platforms, using the Google AI Edge LiteRT (formerly TensorFlow Lite) and MediaPipe Tasks libraries. For insights on other powerful tools, be sure to explore the rest of the AI on Android Spotlight Week content.

To make getting started with Google AI Edge easier, we've provided samples on GitHub as an executable codelab. They demonstrate how to convert the MobileViT model for image classification (compatible with MediaPipe Tasks) and the DIS model for segmentation (compatible with LiteRT).

[Image: a red Android figurine next to a black-and-white silhouette of the same figure, labeled 'Original Image' and 'PT Mask', demonstrating image segmentation]
DIS model output

This post walks you through using the MobileViT model with MediaPipe Tasks. Keep in mind that the LiteRT runtime provides similar capabilities, enabling you to build custom pipelines and features.

Convert MobileViT model for image classification compatible with MediaPipe Tasks

Once you've installed the necessary dependencies and utilities for your app, the first step is to retrieve the PyTorch model you wish to convert, along with any other MobileViT components you might need (such as an image processor for testing).
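If you're starting from a clean Python environment, the core packages can typically be installed with pip (the package names below are assumed; pin versions as needed):

pip install ai-edge-torch mediapipe transformers pillow

With the dependencies in place, pull the model and its image processor from Hugging Face: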

from transformers import MobileViTImageProcessor, MobileViTForImageClassification

hf_model_path = 'apple/mobilevit-small'
processor = MobileViTImageProcessor.from_pretrained(hf_model_path)
pt_model = MobileViTForImageClassification.from_pretrained(hf_model_path)
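Before converting anything, it's worth a quick check that the PyTorch model behaves as expected. A minimal sketch, assuming an RGB test image named sample.jpg (the filename is illustrative):

import torch
from PIL import Image

# Preprocess a test image and run the original PyTorch model.
image = Image.open('sample.jpg').convert('RGB')
inputs = processor(images=image, return_tensors='pt')
with torch.no_grad():
    logits = pt_model(**inputs).logits

# Map the top logit back to a human-readable label.
print(pt_model.config.id2label[int(logits.argmax(-1))])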

Since the end result of this tutorial needs to work with MediaPipe Tasks, take an extra step to match the model's expected input and output shapes to those used by the MediaPipe image classification Task.

import torch
from torch import nn

class HF2MP_ImageClassificationModelWrapper(nn.Module):

  def __init__(self, hf_image_classification_model, hf_processor):
    super().__init__()
    self.model = hf_image_classification_model
    if hf_processor.do_rescale:
      self.rescale_factor = hf_processor.rescale_factor
    else:
      self.rescale_factor = 1.0

  def forward(self, image: torch.Tensor):
    # BHWC -> BCHW.
    image = image.permute(0, 3, 1, 2)
    # RGB -> BGR.
    image = image.flip(dims=(1,))
    # Scale [0, 255] -> [0, 1].
    image = image * self.rescale_factor
    logits = self.model(pixel_values=image).logits  # [B, 1000] float32.
    # Softmax is required for MediaPipe classification model.
    logits = torch.nn.functional.softmax(logits, dim=-1)

    return logits

hf_model_path = 'apple/mobilevit-small'
hf_mobile_vit_processor = MobileViTImageProcessor.from_pretrained(hf_model_path)
hf_mobile_vit_model = MobileViTForImageClassification.from_pretrained(hf_model_path)
wrapped_pt_model = HF2MP_ImageClassificationModelWrapper(
    hf_mobile_vit_model, hf_mobile_vit_processor).eval()
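To confirm the wrapper behaves as intended, you can feed it a dummy BHWC image in the [0, 255] range and check that the output is a valid probability distribution. This check is illustrative rather than part of the original pipeline:

# Dummy BHWC input in [0, 255], matching what MediaPipe will supply.
dummy_image = torch.rand(1, 256, 256, 3) * 255.0
with torch.no_grad():
    probs = wrapped_pt_model(dummy_image)

print(probs.shape)         # torch.Size([1, 1000])
print(float(probs.sum()))  # ~1.0, since the wrapper applies softmax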

Whether you plan to use the converted MobileViT model with MediaPipe Tasks or LiteRT, the next step is to convert the model to the .tflite format.

First, match the input shape. In this example, the input shape is (1, 256, 256, 3) for a batch of one 256x256 pixel, three-channel RGB image.

Then, call AI Edge Torch's convert function to complete the conversion process.

import ai_edge_torch

sample_args = (torch.rand((1, 256, 256, 3)),)
edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)
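At this point you can sanity-check the conversion by running the same sample input through both models and comparing the outputs. The ai_edge_torch documentation shows converted models being called directly like this; treat the tolerance below as a rough starting point:

import numpy as np

# Run the same input through the original and converted models.
torch_output = wrapped_pt_model(*sample_args).detach().numpy()
edge_output = edge_model(*sample_args)

# The two outputs should agree to within a small numerical tolerance.
print(np.allclose(torch_output, edge_output, atol=1e-5))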

After converting the model, you can further refine it by incorporating metadata for the image classification labels. MediaPipe Tasks will utilize this metadata to display or return pertinent information after classification.

from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.vision.image_classifier import ImageClassifier
from pathlib import Path

flatbuffer_file = Path('hf_mobile_vit_mp_image_classification_raw.tflite')
edge_model.export(flatbuffer_file)
tflite_model_buffer = flatbuffer_file.read_bytes()

# Extract the image classification labels from the HF model for later integration into the TFLite model.
labels = list(hf_mobile_vit_model.config.id2label.values())

writer = image_classifier.MetadataWriter.create(
    tflite_model_buffer,
    input_norm_mean=[0.0], #  Normalization is not needed for this model.
    input_norm_std=[1.0],
    labels=metadata_writer.Labels().add(labels),
)
tflite_model_buffer, _ = writer.populate()
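The populate() call returns the final model buffer, which you then write to disk. A minimal sketch; the output filename is illustrative:

# Save the metadata-populated model for bundling with the Android app.
output_file = Path('hf_mobile_vit_mp_image_classification.tflite')
output_file.write_bytes(tflite_model_buffer)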

With all of that completed, it's time to integrate your model into an Android app. If you're following the official Colab notebook, this involves saving the model locally. For an example of image classification with MediaPipe Tasks, explore the GitHub repository. You can find more information in the official Google AI Edge documentation.
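Before wiring the model into an Android app, you can also verify the final .tflite file with the MediaPipe Tasks Python API, which mirrors the Android API. A minimal sketch, assuming the model file saved above and an illustrative test image sample.jpg:

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# Load the converted model into a MediaPipe image classifier.
base_options = python.BaseOptions(
    model_asset_path='hf_mobile_vit_mp_image_classification.tflite')
options = vision.ImageClassifierOptions(base_options=base_options, max_results=3)
classifier = vision.ImageClassifier.create_from_options(options)

# Classify a test image and print the top results.
mp_image = mp.Image.create_from_file('sample.jpg')
result = classifier.classify(mp_image)
for category in result.classifications[0].categories:
    print(category.category_name, category.score)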

[Animation: the newly converted ViT model running image classification with MediaPipe Tasks]
Newly converted ViT model with MediaPipe Tasks

Now that you've seen how to convert a simple image classification model, you can use the same techniques to adapt a variety of PyTorch models for Google AI Edge LiteRT or MediaPipe Tasks tooling on Android.

For further model optimization, consider techniques like quantization during conversion. Check out the GitHub example to learn more about how to convert a PyTorch image segmentation model to LiteRT and quantize it.
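As a rough illustration, the ai_edge_torch repository documents a PT2E-based quantization flow along these lines. The exact APIs vary across torch and ai_edge_torch versions, so treat this as a sketch and defer to the GitHub example:

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.quant_config import QuantConfig

# Configure dynamic, per-channel symmetric quantization.
quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))

# Capture and prepare the model, run it once so observers record
# activation ranges, then convert it to a quantized graph.
pt2e_model = capture_pre_autograd_graph(wrapped_pt_model, sample_args)
pt2e_model = prepare_pt2e(pt2e_model, quantizer)
pt2e_model(*sample_args)
pt2e_model = convert_pt2e(pt2e_model, fold_quantize=False)

# Convert the quantized graph with ai_edge_torch, passing the quantizer config.
quantized_edge_model = ai_edge_torch.convert(
    pt2e_model, sample_args, quant_config=QuantConfig(pt2e_quantizer=quantizer))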

What's Next

To keep up to date on Google AI Edge developments, look for announcements on the Google for Developers YouTube channel and blog.

We look forward to hearing how you're using these features in your projects. Use the #AndroidAI hashtag to share your feedback or what you've built on social media, and check out other content from AI on Android Spotlight Week!