Pooya Abolghasemi


We start from the Android sample project for image classification, which classifies the camera's images into ImageNet classes. We are going to modify the TFLite model and the Android Java code so that the app returns the extracted features along with the classification probabilities.

Customize the TFLite Model

The sample Android project automatically downloads and uses a predefined TFLite model. If you want to try a different TFLite model, you can download one from here. For classification, both quantized (smaller, faster, less accurate) and floating-point (larger, slower, more accurate) models are available. Download a model based on your requirements. In the compressed file, you will find 7 files. I downloaded the MobileNetV2 floating-point model. Note that throughout this post you should adapt the commands to the model you downloaded.


(protobuf files - contains the frozen model)

(TFLite model - can be attained by converting the frozen model)

The compressed file contains checkpoint files so one can load the model and modify the layers. To be able to load the model, a file named checkpoint with the following content is needed.

model_checkpoint_path: "mobilenet_v2_1.0_224.ckpt"
all_model_checkpoint_paths: "mobilenet_v2_1.0_224.ckpt"

Check out this post to learn about loading the checkpoint files and freezing them for later use. For the purpose of this post, no modification to the model's architecture is needed. To convert a model to a TFLite model, we first need to freeze it. The frozen model can then be converted to TFLite using the tflite_convert script. Starting from TF 1.9, tflite_convert is installed as part of the TF Python package. Check out this official post for more information regarding tflite_convert. We are going to expose an intermediate layer as an additional output. This is possible even when the checkpoint files are not provided: we only need the *.pb files to modify the outputs.

To add an intermediate layer to the model's outputs, we first need to choose a layer. The file mobilenet_v2_1.0_224_eval.pbtxt contains the layers' information, but it is very large and hard to read. Alternatively, you can load the pbtxt file in TensorBoard to get a better sense of the network's architecture. I chose the layer MobilenetV2/Logits/AvgPool with the shape (1, 1, 1280).

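Here is the tflite_convert command that exports the new TFLite model with the additional output. The --output_arrays flag lists the AvgPool layer first and the original prediction output second, which is the order the Android code below relies on. MobilenetV2/Predictions/Reshape_1 is the prediction output name in the released MobileNetV2 graph; check it against your own model.

  tflite_convert \
    --output_file=customized.tflite \
    --graph_def_file=mobilenet_v2_1.0_224_frozen.pb \
    --input_arrays=input \
    --input_shapes=1,224,224,3 \
    --output_arrays=MobilenetV2/Logits/AvgPool,MobilenetV2/Predictions/Reshape_1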

Customizing the Android Project

First, let's modify the values of the model-type spinner in the file app/src/main/res/values/strings.xml and remove the Quantized option, since we only converted the floating-point model. We also need to change the model's path in the file ClassifierFloatMobileNet.java.

  protected String getModelPath() {
    // The customized model exported above. Unlike the default model, it is not
    // downloaded by build.gradle, so it has to be placed in the assets folder manually.
    return "customized.tflite";
  }

All other changes happen in the file Classifier.java. The function recognize_image in Classifier.java returns a list of Recognitions, where Recognition is a class defined in the same file. To keep the modifications minimal, let us add a new field to the Recognition class to store the features from the AvgPool layer. We will save the intermediate features inside the first Recognition in the list returned by recognize_image.

  /** A result returned by a Classifier describing what was recognized. */
  public static class Recognition {
    private final String id;
    private final String title;
    private final Float confidence;
    private RectF location;
    // Intermediate features from the AvgPool layer; only set on the top result.
    private float[] features;

    public Recognition(String id, String title, Float confidence, RectF location, float[] features) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.location = location;
      this.features = features;
    }

    public Recognition(final String id, final String title, final Float confidence, final RectF location) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.location = location;
    }

    public String getId() {
      return id;
    }

    public String getTitle() {
      return title;
    }

    public Float getConfidence() {
      return confidence;
    }

    public RectF getLocation() {
      return new RectF(location);
    }

    public float[] getFeatures() {
      return features;
    }

    public void setFeatures(float[] features) {
      this.features = features;
    }

    public void setLocation(RectF location) {
      this.location = location;
    }

    @Override
    public String toString() {
      String resultString = "";
      if (id != null) {
        resultString += "[" + id + "] ";
      }

      if (title != null) {
        resultString += title + " ";
      }

      if (confidence != null) {
        resultString += String.format("(%.1f%%) ", confidence * 100.0f);
      }

      if (location != null) {
        resultString += location + " ";
      }

      return resultString.trim();
    }
  }

The sample project runs inference with run(Object input, Object output) of the class org.tensorflow.lite.Interpreter, which is sufficient because the original model takes a single input and produces a single output. Our model now has two outputs, so we need runForMultipleInputsOutputs(@NonNull Object[] inputs, @NonNull Map<Integer,Object> outputs) instead. We will add a Map<Integer, Object> that maps each output index to its output buffer.

  Map<Integer, Object> outputBuffers = new HashMap<>();
  private final TensorBuffer outputProbabilityBuffer;
  private final TensorBuffer outputFeatureBuffer;

And initialize them in the constructor.

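List<int[]> outputShapes = new ArrayList<>();
List<DataType> outputTypes = new ArrayList<>();
// Record the shape and data type of every output tensor of the model.
for (int i = 0; i < tflite.getOutputTensorCount(); i++) {
  outputShapes.add(tflite.getOutputTensor(i).shape());
  outputTypes.add(tflite.getOutputTensor(i).dataType());
}

// Output 0 holds the AvgPool features, output 1 the class probabilities.
outputFeatureBuffer = TensorBuffer.createFixedSize(outputShapes.get(0), outputTypes.get(0));
outputProbabilityBuffer = TensorBuffer.createFixedSize(outputShapes.get(1), outputTypes.get(1));
outputBuffers.put(0, outputFeatureBuffer.getBuffer().rewind());
outputBuffers.put(1, outputProbabilityBuffer.getBuffer().rewind());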
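To make the allocation above more concrete, here is a minimal plain-Java sketch of roughly what a fixed-size float output buffer amounts to: a direct, native-order ByteBuffer with four bytes per element, stored in the map under the output's index. The class name OutputBufferSketch, its helper methods, and the (1, 1001) probability shape are assumptions made for illustration only; in the app itself, TensorBuffer.createFixedSize takes care of the allocation.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.Map;

final class OutputBufferSketch {
  private static final int BYTES_PER_FLOAT = 4;

  /** Number of elements in a tensor: the product of its dimensions. */
  static int elementCount(int[] shape) {
    int count = 1;
    for (int dim : shape) {
      count *= dim;
    }
    return count;
  }

  /** Allocates a direct, native-order buffer sized for a float tensor of the given shape. */
  static ByteBuffer allocateFloatBuffer(int[] shape) {
    ByteBuffer buffer = ByteBuffer.allocateDirect(elementCount(shape) * BYTES_PER_FLOAT);
    buffer.order(ByteOrder.nativeOrder());
    return buffer;
  }

  /** Builds the index-to-buffer map: output 0 holds features, output 1 holds probabilities. */
  static Map<Integer, Object> buildOutputMap(int[] featureShape, int[] probabilityShape) {
    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, allocateFloatBuffer(featureShape));
    outputs.put(1, allocateFloatBuffer(probabilityShape));
    return outputs;
  }

  public static void main(String[] args) {
    int[] featureShape = {1, 1, 1280};   // AvgPool shape quoted in this post
    int[] probabilityShape = {1, 1001};  // assumed class count; check your own model
    Map<Integer, Object> outputs = buildOutputMap(featureShape, probabilityShape);
    ByteBuffer features = (ByteBuffer) outputs.get(0);
    System.out.println("feature buffer bytes: " + features.capacity());
  }
}
```

The point of the sketch is only that each output index needs a buffer whose size matches the corresponding tensor, which is why the shapes and data types are collected from the interpreter first.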

In the recognize_image function, we simply replace tflite.run with tflite.runForMultipleInputsOutputs.

Object[] inputs = {inputImageBuffer.getBuffer()};
tflite.runForMultipleInputsOutputs(inputs, outputBuffers);

After inference, the AvgPool features can be read from the feature buffer as follows:

float[] features = outputFeatureBuffer.getFloatArray();
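The round trip of handing a map of buffers to the interpreter and reading the results back can be sketched without any TFLite dependency. In the plain-Java example below, fakeRunForMultipleOutputs is a hypothetical stand-in for the interpreter that simply writes constant values into each buffer; the class name, helper names, and the 1001-element probability buffer are likewise assumptions for illustration. What it shows is that outputs are filled in place by index and that the buffer should be rewound before reading.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.Map;

final class MultiOutputSketch {

  /** Hypothetical stand-in for an interpreter: fills every output buffer in place. */
  static void fakeRunForMultipleOutputs(Map<Integer, Object> outputs) {
    for (Map.Entry<Integer, Object> entry : outputs.entrySet()) {
      ByteBuffer buffer = (ByteBuffer) entry.getValue();
      int index = entry.getKey();
      buffer.rewind();
      while (buffer.remaining() >= 4) {
        buffer.putFloat(index + 0.5f); // output 0 gets 0.5, output 1 gets 1.5
      }
    }
  }

  /** Copies the whole buffer into a float array, starting from position 0. */
  static float[] toFloatArray(ByteBuffer buffer) {
    buffer.rewind();
    float[] values = new float[buffer.remaining() / 4];
    buffer.asFloatBuffer().get(values);
    return values;
  }

  /** Allocates a direct, native-order buffer for the given number of floats. */
  static ByteBuffer newFloatBuffer(int elements) {
    return ByteBuffer.allocateDirect(elements * 4).order(ByteOrder.nativeOrder());
  }

  public static void main(String[] args) {
    ByteBuffer featureBuffer = newFloatBuffer(1280);
    ByteBuffer probabilityBuffer = newFloatBuffer(1001); // assumed class count

    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, featureBuffer);
    outputs.put(1, probabilityBuffer);

    fakeRunForMultipleOutputs(outputs);

    float[] features = toFloatArray(featureBuffer);
    System.out.println("features: " + features.length + ", first value: " + features[0]);
  }
}
```

In the real app, the call to tflite.runForMultipleInputsOutputs plays the role of the stand-in, and outputFeatureBuffer.getFloatArray() plays the role of toFloatArray.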

The recognize_image function returns the top-k classes with the highest probabilities by calling getTopKProbability. We pass the AvgPool features to this function as an additional argument and attach them to the Recognition with the highest confidence.

  /** Gets the top-k results and attaches the features to the most confident one. */
  private static List<Recognition> getTopKProbability(Map<String, Float> labelProb, float[] features) {
    // Find the best classifications.
    PriorityQueue<Recognition> pq =
            new PriorityQueue<>(
                    MAX_RESULTS,
                    new Comparator<Recognition>() {
                      @Override
                      public int compare(Recognition lhs, Recognition rhs) {
                        // Intentionally reversed to put high confidence at the head of the queue.
                        return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                      }
                    });

    for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
      pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue(), null));
    }

    final ArrayList<Recognition> recognitions = new ArrayList<>();
    int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < recognitionsSize; ++i) {
      Recognition toAdd = pq.poll();
      // Only the first (most confident) result carries the AvgPool features.
      if (i == 0) {
        toAdd.setFeatures(features);
      }
      recognitions.add(toAdd);
    }
    return recognitions;
  }
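The top-k selection described above can be tried outside the Android project with the following self-contained sketch. It uses a simplified Result class in place of Recognition (no id or location, since RectF is Android-specific), and the class names, labels, and probabilities are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

final class TopKSketch {
  static final int MAX_RESULTS = 3;

  /** Simplified, hypothetical stand-in for the Recognition class. */
  static final class Result {
    final String title;
    final float confidence;
    float[] features; // only set on the most confident result

    Result(String title, float confidence) {
      this.title = title;
      this.confidence = confidence;
    }
  }

  /** Returns up to MAX_RESULTS results, most confident first, with features on the first one. */
  static List<Result> topK(Map<String, Float> labelProb, float[] features) {
    PriorityQueue<Result> pq =
        new PriorityQueue<>(
            MAX_RESULTS,
            new Comparator<Result>() {
              @Override
              public int compare(Result lhs, Result rhs) {
                // Reversed so that the highest confidence sits at the head of the queue.
                return Float.compare(rhs.confidence, lhs.confidence);
              }
            });
    for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
      pq.add(new Result(entry.getKey(), entry.getValue()));
    }

    List<Result> results = new ArrayList<>();
    int size = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < size; ++i) {
      Result next = pq.poll();
      if (i == 0) {
        next.features = features;
      }
      results.add(next);
    }
    return results;
  }

  public static void main(String[] args) {
    Map<String, Float> labelProb = new LinkedHashMap<>();
    labelProb.put("cat", 0.10f);
    labelProb.put("dog", 0.60f);
    labelProb.put("fox", 0.25f);
    labelProb.put("owl", 0.05f);

    float[] features = new float[1280]; // placeholder for the AvgPool output
    for (Result r : topK(labelProb, features)) {
      System.out.println(r.title + " " + r.confidence + " features=" + (r.features != null));
    }
  }
}
```

Storing the features only on the first result keeps the return type of the method unchanged, at the cost of callers having to know where to look for them.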

The complete project, forked and modified from the original repo, is available here.