Optimizing Models with Quantization-Aware Training in Keras - Part II

Created: Jan 2, 2026 5:35 PM

This article is Part II of the 'Quantization in Keras' series. In Part I, we saw how to perform Post-Training Quantization in Keras on the Facial Keypoints Detection dataset.

Introduction

We defined "quantization" in the first part of this series, but, by way of reminder: when we talk about quantization in this piece, we're specifically referring to the process of reducing the size of a model, by storing and computing with lower-precision numbers, without significantly sacrificing accuracy. This is often done so that models can be deployed on edge devices like smartphones or microcontrollers.

In Part I, we saw that Post-Training Quantization reduced the model size by nearly 10x (from 228 MB to 25 MB) and slightly improved the model's accuracy as well. In Part II, we'll look at the impact of Quantization-Aware Training on our model.

With Quantization-Aware Training, we fine-tune the already trained model while emulating quantization, and then convert it into a quantized model.

The first part of this article remains largely the same: we'll train a baseline ResNet-50 model on a dataset of images of people's faces. If you have read Part I, you can jump directly to the section called Quantization-Aware Training.

Let's quantize:

Our Data and Our Goal

We’ll be using the Facial Keypoints Detection dataset from Kaggle for our tutorial and experiments today. The goal of that competition, simply, was to detect keypoints on pictures of people’s faces.

Essentially, keypoints help a model locate and orient a face. They are, well, key points for a model: generally things like the tip of the nose, the center of an eye, or the corner of the mouth, rather than the center of a cheek or the forehead. Each keypoint is specified by the x and y pixel coordinates of its location. In this dataset, we are given the x and y coordinates of 15 such keypoints for grayscale images, so there are 30 target values for the model to predict per picture. The input image is given in the last field of the data files and consists of a list of pixels (ordered by row), as integers in the range 0 to 255. The images are 96x96 pixels.
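To make the label layout concrete: each row carries 30 numbers and, assuming they alternate x and y per keypoint (which is how the plotting loop further down reads them), they regroup naturally into 15 (x, y) pairs. Here is a minimal numpy sketch; `to_keypoint_pairs` is a hypothetical helper, not part of the article's code.

```python
import numpy as np

NUM_KEYPOINTS = 15  # keypoints per fully annotated image


def to_keypoint_pairs(labels):
    """Reshape a flat vector of 30 label values into 15 (x, y) rows.

    Hypothetical helper: assumes the values alternate x, y, x, y, ...
    """
    labels = np.asarray(labels, dtype=np.float32)
    if labels.shape != (2 * NUM_KEYPOINTS,):
        raise ValueError(f"expected {2 * NUM_KEYPOINTS} values, got shape {labels.shape}")
    return labels.reshape(NUM_KEYPOINTS, 2)


# Synthetic label vector: x values are 0, 2, 4, ... and y values are 1, 3, 5, ...
example = np.arange(30, dtype=np.float32)
pairs = to_keypoint_pairs(example)
xs, ys = pairs[:, 0], pairs[:, 1]
```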

Again, our goal here is pretty basic: to detect the locations of the keypoints and to use quantization to reduce the size of our model.

Preliminaries

We start by importing the necessary libraries. We will see how we use them as we move ahead.

Global Configuration

It's always good practice to have a separate configuration file or class that holds all the parameters in one place. We have four files in the dataset:

  • Training File - List of 7049 training images. Each row contains the (x,y) coordinates for 15 keypoints, and image data as a row-ordered list of pixels.
  • Test File - List of 1783 test images. Each row contains an ImageId and image data as a row-ordered list of pixels.
  • ID Lookup Table - The keypoints that are to be predicted for each test image.
  • Sample Submission - A template for the final submission file containing the predicted locations of the keypoints.
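A configuration class along these lines keeps paths and hyperparameters in one place. This is a minimal sketch: the attribute names TRAIN_FILE_PATH, TEST_FILE_PATH, LEARNING_RATE, BASELINE_CHECKPOINT_PATH and QUANTIZATION_AWARE_TRAINING_PATH are the ones used later in the article, but every value below, as well as the DATA_DIR, LOOKUP_FILE_PATH and SAMPLE_SUBMISSION_PATH names and the CSV file names, is a hypothetical placeholder to adapt to your own setup.

```python
import os


class Config:
    """Central place for paths and hyperparameters (all values are placeholders)."""

    # Hypothetical local locations of the four Kaggle files
    DATA_DIR = "data"
    TRAIN_FILE_PATH = os.path.join(DATA_DIR, "training.csv")
    TEST_FILE_PATH = os.path.join(DATA_DIR, "test.csv")
    LOOKUP_FILE_PATH = os.path.join(DATA_DIR, "IdLookupTable.csv")
    SAMPLE_SUBMISSION_PATH = os.path.join(DATA_DIR, "SampleSubmission.csv")

    # Image and label geometry described in the article
    IMAGE_SIZE = 96      # images are 96x96 pixels
    NUM_KEYPOINTS = 15   # 15 (x, y) pairs -> 30 target values

    # Training hyperparameter (placeholder value)
    LEARNING_RATE = 1e-3

    # Output locations (placeholder paths)
    BASELINE_CHECKPOINT_PATH = os.path.join("checkpoints", "baseline")
    QUANTIZATION_AWARE_TRAINING_PATH = os.path.join("models", "qat_model.tflite")


config = Config()
```

Plain strings are used for paths so they can be passed to pandas and TensorFlow alike.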

Load Data

Now that we have set up our configuration, we load our dataset using the load_data() function.

import pandas as pd

def load_data():
    """
    Function to load the train and test datasets
    """
    X_train = pd.read_csv(config.TRAIN_FILE_PATH)
    X_test_original = pd.read_csv(config.TEST_FILE_PATH)

    return X_train, X_test_original

Check Missing Values

The training data contains many missing values. Inspecting it shows that some images are annotated with all 15 keypoints, whereas others have only 4.

We can take two approaches: either fill or remove the missing values and train one model on the entire dataset, or use two models, one for images with 4 keypoints and the other for images with 15 keypoints, and then combine their predictions.

For simplicity, we drop the samples with missing values, since we want to train only one model on the entire dataset and report the changes observed after optimization.

def remove_missing_values():
    """
    Function to drop the samples with missing values
    """
    X_train, X_test_original = load_data()

    # Keep only images annotated with all 15 keypoints
    X_train = X_train.dropna()
    # Labels are every column except the raw pixel string
    y_train = X_train.drop(['Image'], axis=1)

    return X_train, y_train, X_test_original
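What dropna() does to the labels can be seen on a tiny synthetic example: a row survives only if none of its 30 coordinates is missing (NaN), so an image annotated with just 4 keypoints is discarded. A minimal numpy sketch; `keep_complete_rows` and `count_keypoints` are hypothetical helpers, and the "first 4 keypoints" pattern is made up for illustration.

```python
import numpy as np


def keep_complete_rows(labels):
    """Keep only rows with no missing (NaN) values, like DataFrame.dropna()."""
    labels = np.asarray(labels, dtype=np.float64)
    complete = ~np.isnan(labels).any(axis=1)
    return labels[complete], complete


def count_keypoints(labels):
    """Number of annotated keypoints per row (each keypoint has an x and a y value)."""
    return (~np.isnan(np.asarray(labels, dtype=np.float64))).sum(axis=1) // 2


# Synthetic labels: row 0 has all 15 keypoints, row 1 only 4 of them (8 values).
full = np.arange(30, dtype=np.float64)
partial = np.full(30, np.nan)
partial[:8] = np.arange(8)
labels = np.stack([full, partial])

kept, mask = keep_complete_rows(labels)
counts = count_keypoints(labels)
```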

Obtain Images

In the training data, each image is stored as a string of space-separated pixel values, so we need to convert these strings into image arrays.
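The conversion boils down to splitting the string and reshaping the 9,216 (96 × 96) numbers into a square array. A minimal sketch, assuming the row-major pixel order described for the dataset; `parse_image` is a hypothetical helper, and the article's own get_images() may differ (for example in how it scales pixel values).

```python
import numpy as np

IMAGE_SIZE = 96


def parse_image(pixel_string, image_size=IMAGE_SIZE):
    """Convert a space-separated pixel string into an (image_size, image_size) float array."""
    pixels = np.asarray(pixel_string.split(), dtype=np.float32)
    if pixels.size != image_size * image_size:
        raise ValueError(f"expected {image_size * image_size} pixels, got {pixels.size}")
    # Pixels are listed row by row, so a row-major (C-order) reshape restores the image.
    return pixels.reshape(image_size, image_size)


# Synthetic example: encode a 96x96 pattern the same way as the CSV's Image column.
values = (np.arange(IMAGE_SIZE * IMAGE_SIZE) % 256).astype(int)
row = " ".join(map(str, values))
image = parse_image(row)

# On the real data this might look like:
# images = np.stack([parse_image(s) for s in X_train['Image']])
```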

Let us view some images with their keypoints to understand the problem better.

import matplotlib.pyplot as plt

# View Images with their keypoints
X_train, y_train, _ = get_images()

fig = plt.figure(figsize=(9, 9))

for i in range(9):
    ax = fig.add_subplot(3, 3, i + 1)
    ax.imshow(X_train[i], cmap='gray')
    # Label columns alternate x, y for each of the 15 keypoints
    for j in range(0, 30, 2):
        ax.plot(y_train.iloc[i, j], y_train.iloc[i, j + 1], 'ro')

plt.show()
Training Images with their Keypoints

Each image has 15 keypoints, shown as red dots at different locations on the face. These locations (the x and y pixel coordinates) are what we are trying to predict.

Prepare the Data

Next, we convert the data to numpy arrays, then flatten and reshape the images into the desired format.

To prepare our validation dataset, we will simply use the train_test_split function offered by scikit-learn.

After preparing the training, validation and test datasets, we convert the images to 3 channels (originally they have only 1) using the get_correct_dimensions() function, because the pretrained ResNet50 expects 3-channel input.
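The channel conversion can be done by repeating the single grayscale channel three times. A minimal numpy sketch; `to_three_channels` is a hypothetical stand-in for get_correct_dimensions(), whose actual implementation may differ.

```python
import numpy as np


def to_three_channels(images):
    """Turn a batch of grayscale images (N, H, W) into (N, H, W, 3) by copying the channel."""
    images = np.asarray(images, dtype=np.float32)
    if images.ndim != 3:
        raise ValueError(f"expected shape (N, H, W), got {images.shape}")
    # Add a trailing channel axis, then repeat it 3 times along that axis.
    return np.repeat(images[..., np.newaxis], 3, axis=-1)


# Synthetic batch of four 96x96 grayscale images
batch = np.random.default_rng(0).integers(0, 256, size=(4, 96, 96)).astype(np.float32)
rgb_batch = to_three_channels(batch)
```

Copying the channel keeps the pixel information unchanged while matching the input shape the pretrained weights were built for.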

ResNet50 Baseline Model

At this point, we have preprocessed the data and obtained the datasets in the correct format, so they are ready for training.

For training the model, we will implement Transfer Learning using a pretrained ResNet50 model. You can try out larger models like EfficientNets too, but for our use case a single ResNet model is enough to achieve a decent score.

Before training the model, we need to define some callbacks. In this case, we define four callbacks:

  1. TensorBoard to log training metrics
  2. Early Stopping to prevent the model from overfitting
  3. Checkpoint Callback to save the model during training
  4. WandB Callback to log everything to our wandb dashboard

Then we define our ResNet model with pretrained weights. We will freeze the pre-trained layers and add custom layers at the end.

Finally we define the entire architecture consisting of 3 simple data augmentations, followed by the ResNet model, a flatten layer and a dense layer.

Next, we define the optimizer, which in our case is Adam, and compile the model for training. Notice that we use mean squared error as the loss and Root Mean Squared Error as the metric.

def make_model():
    """
    Function to load model, define optimizer and compile the model
    """
    model = setup_model()

    # Define Optimizer
    optimizer = keras.optimizers.Adam(learning_rate=config.LEARNING_RATE)

    # Compile the model
    model.compile(
        optimizer=optimizer,
        loss='mean_squared_error',
        metrics=[tf.keras.metrics.RootMeanSquaredError()],
    )

    return model
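Because the metric is the square root of the mean squared error, RMSE is expressed in pixels, the same unit as the keypoint coordinates, which makes it easier to interpret than the MSE loss. A small numpy sketch of the relationship, on made-up values:

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean squared error over all coordinates (the training loss)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return float(np.mean((y_true - y_pred) ** 2))


def rmse(y_true, y_pred):
    """Root mean squared error: the square root of the MSE, in pixel units."""
    return float(np.sqrt(mse(y_true, y_pred)))


# Two images with 30 coordinates each; every prediction is off by 3 pixels.
y_true = np.zeros((2, 30))
y_pred = np.full((2, 30), 3.0)
```

Here the MSE is 9 (squared pixels) while the RMSE is 3, the typical per-coordinate error in pixels.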

Next, we load the datasets, obtain the callbacks, and define the model training function.

Run Model

We initialize wandb to record our run and start training our model.

run = wandb.init(project='Quantization Aware Training')

history = run_model(config.BASELINE_CHECKPOINT_PATH)

wandb.finish()

From the charts above, it is clear that as the number of epochs increases, both the training and validation losses decrease, so there is no sign of overfitting.

Model Evaluation and Size

We create two functions: evaluate_model(), which computes the metric (root mean squared error) on the validation dataset, and get_model_size(), which returns the size of the model.

# Evaluate Model
evaluate_model(config.BASELINE_CHECKPOINT_PATH)

# Model Size
get_model_size(config.BASELINE_CHECKPOINT_PATH)
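A Keras checkpoint may be a single file or a directory (for example, the SavedModel format), while a .tflite model is a single file, so a size helper has to handle both. A minimal sketch using only the standard library; `model_size_mb` is hypothetical and only approximates what the article's get_model_size() (with its checkpoint flag) does. MB here means 2**20 bytes.

```python
import os


def model_size_mb(path):
    """Return the on-disk size of a model file or model directory in MB (2**20 bytes)."""
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    if os.path.isfile(path):
        total = os.path.getsize(path)
    else:
        total = 0
        # Walk the directory tree (e.g. a SavedModel folder) and sum all file sizes.
        for root, _dirs, files in os.walk(path):
            for name in files:
                total += os.path.getsize(os.path.join(root, name))
    return total / (1024 * 1024)
```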

Baseline Model Conclusion:

These are the results obtained from the baseline model we trained.

Model Description -

  • ResNet50 with Pretrained Weights along with Data Augmentations

Result -

  • Root Mean Squared Error = 3.05
  • Training Time = 776 seconds
  • Model Size = 228 MB

We will refer back to this conclusion after we quantize the model.

Quantization-Aware Training

Quantization-Aware Training emulates inference-time quantization, creating a model that downstream tools will use to produce actually quantized models. The quantized models use lower precision (e.g., 8-bit integers instead of 32-bit floats), which leads to benefits during deployment.
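The core of "emulating" quantization is fake quantization: in the forward pass, values are rounded onto an 8-bit grid and immediately mapped back to float, so the network trains while already seeing the rounding error it will face after conversion. (In practice, the non-differentiable rounding is typically handled with a straight-through estimator during backpropagation; this sketch covers only the forward pass.) The numpy sketch below uses a common asymmetric, per-tensor scale and zero-point scheme to illustrate the concept; it is not claimed to match TensorFlow Model Optimization's internals exactly. It also shows why 8-bit storage is four times smaller than float32.

```python
import numpy as np

QMIN, QMAX = 0, 255  # unsigned 8-bit integer range


def quantize_params(x):
    """Compute a per-tensor scale and zero point mapping [min(x), max(x)] onto [QMIN, QMAX]."""
    # Extend the range to include 0.0 so that zero is represented exactly.
    x_min = min(float(np.min(x)), 0.0)
    x_max = max(float(np.max(x)), 0.0)
    scale = (x_max - x_min) / (QMAX - QMIN) if x_max > x_min else 1.0
    zero_point = int(round(QMIN - x_min / scale))
    return scale, zero_point


def fake_quantize(x):
    """Quantize float values to uint8, then dequantize them back to float32."""
    scale, zero_point = quantize_params(x)
    q = np.clip(np.round(x / scale) + zero_point, QMIN, QMAX).astype(np.uint8)
    dequantized = (q.astype(np.float32) - zero_point) * np.float32(scale)
    return dequantized, q, scale


# Synthetic float32 "weights" for a small Dense layer (2048 inputs -> 30 outputs).
rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=(2048, 30)).astype(np.float32)
dequantized, q, scale = fake_quantize(weights)

# Rounding error per value is at most half a quantization step.
max_error = float(np.max(np.abs(weights - dequantized)))
```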

Define the Model

There are multiple ways to implement the Quantization-Aware Training method.

You can choose to quantize the entire model, i.e., all of its layers, or to quantize only selected layers.

The question is: if quantization reduces the model size, wouldn't it be better to just quantize the entire model? Why do we have the option to quantize only some layers?

The reason is simple: quantizing a model can have a negative effect on accuracy. By quantizing layers selectively, we can explore the trade-off between accuracy, speed, and model size.

Quantization-Aware Training is a fine-tuning method. Thus, we need to create a quantization-aware model and train it again. With Post-Training Quantization, by contrast, we did not need to retrain the model.

Step I is loading the Baseline ResNet-50 model. Here's the code:

# Load model with pretrained weights
model = setup_pretrained_model(config.BASELINE_CHECKPOINT_PATH)

In Step II, we choose which layers to quantize. Here we quantize only the Dense layers of the model, and this needs to be specified to TensorFlow.

So we define a helper function, apply_quantization_to_dense(), which annotates only the Dense layers for quantization using the quantize_annotate_layer() function from the TensorFlow Model Optimization Toolkit (tfmot); all other layers are returned unchanged.

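import tensorflow as tf
import tensorflow_model_optimization as tfmot

def apply_quantization_to_dense(layer):
    # Annotate Dense layers for quantization; return every other layer unchanged
    if isinstance(layer, tf.keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer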

Next, in Step III, we clone the model before applying quantization. We use tf.keras.models.clone_model with apply_quantization_to_dense as the clone_function, so that the helper is applied to the layers of the model.

annotated_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)

In Step IV, we build the annotated_model, i.e., specify its input shape. With the Dense layers annotated, in Step V quantize_apply actually makes the model quantization-aware.

# Build Model
annotated_model.build((None, 96, 96, 3))

# Now that the Dense layers are annotated,
# `quantize_apply` actually makes the model quantization aware.
quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

In Step VI, similar to training any regular Deep Learning model, we specify the optimizer and compile the model for training.

# Define Optimizer
optimizer = keras.optimizers.Adam(learning_rate=0.03)

# The quantization-aware model must be compiled before training.
quant_aware_model.compile(
    optimizer=optimizer,
    loss='mean_squared_error',
    metrics=[tf.keras.metrics.RootMeanSquaredError()]
)

Moving ahead to Step VII, we load the required datasets and define the Early Stopping Callback along with others. Then we train the model and use wandb to log everything.

From the charts above, we can conclude that quantization-aware fine-tuning did not hurt the model: its error is slightly lower than that of the baseline model.

But wait, we are not done yet!

Converting the quantization-aware model with the TFLite converter gives us an actually quantized model with int8 weights and uint8 activations.

# Convert to TensorFlow Lite Model with optimization
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

Now we save the quantized model by writing it to a file in .tflite format.
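TFLiteConverter.convert() returns the serialized model as bytes, so saving it is a plain binary file write. A minimal sketch; `save_tflite_model` is a hypothetical helper, and in the article's setup the path would be config.QUANTIZATION_AWARE_TRAINING_PATH.

```python
import os


def save_tflite_model(model_bytes, path):
    """Write a serialized TFLite model (the bytes from converter.convert()) and return its size."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)  # create the parent folder if needed
    with open(path, "wb") as f:
        f.write(model_bytes)
    return os.path.getsize(path)


# In the article's flow this would be:
# save_tflite_model(quantized_tflite_model, config.QUANTIZATION_AWARE_TRAINING_PATH)
```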

At this stage, we have a quantized TFLite model. We load it into a TFLite interpreter and evaluate it on the validation dataset (the competition's test set has no keypoint labels).

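# Load the quantized TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path=config.QUANTIZATION_AWARE_TRAINING_PATH)
interpreter.allocate_tensors()

# Root Mean Squared Error
rmse = evaluate_quant_model(interpreter)

print(f"RMSE for Quantization Aware Training: {rmse}")

# Model Size
get_model_size(config.QUANTIZATION_AWARE_TRAINING_PATH, checkpoint=False)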
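For reference, a TFLite interpreter is typically evaluated one sample at a time: set the input tensor, call invoke(), and read the output tensor. A minimal sketch of what an evaluation helper like evaluate_quant_model() could look like, using the tf.lite.Interpreter methods get_input_details(), get_output_details(), set_tensor(), invoke() and get_tensor(). It assumes the converted model keeps float32 inputs and outputs, as it does when only converter.optimizations is set; `tflite_rmse`, `X_valid` and `y_valid` are hypothetical names.

```python
import numpy as np


def tflite_rmse(interpreter, X_valid, y_valid):
    """Run a TFLite interpreter over a validation set and return the RMSE of its predictions.

    Hypothetical helper: `interpreter` is expected to be a tf.lite.Interpreter on which
    allocate_tensors() has already been called.
    """
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    predictions = []
    for image in X_valid:
        # Add a batch dimension and match the dtype the model expects (float32 here).
        batch = np.expand_dims(image, axis=0).astype(input_details["dtype"])
        interpreter.set_tensor(input_details["index"], batch)
        interpreter.invoke()
        predictions.append(interpreter.get_tensor(output_details["index"])[0])

    predictions = np.asarray(predictions, dtype=np.float64)
    errors = predictions - np.asarray(y_valid, dtype=np.float64)
    return float(np.sqrt(np.mean(errors ** 2)))


# Usage in the article's flow (assumed names):
# rmse = tflite_rmse(interpreter, X_valid, y_valid)
```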

Quantization-Aware Training Model Conclusion:

Model Description -

  • ResNet50 with Pretrained Weights along with Data Augmentations
  • Quantization-Aware Training on Dense Layers

Results -

  • Root Mean Squared Error = 2.86 (better than the baseline model and similar to Post-Training Quantization)
  • Training Time = 151 seconds (less than the baseline model)
  • Model Size = 97 MB (considerably smaller than the baseline model, but larger than with Post-Training Quantization)

TFLite Inference

In the final step, we perform inference on the original test dataset provided in the competition. The steps are similar to those in the previous section.

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=config.QUANTIZATION_AWARE_TRAINING_PATH)
interpreter.allocate_tensors()

predictions = tflite_inference(interpreter)

After obtaining the predictions, we create a submission file for the competition, which contains the RowId along with the predicted location of each requested keypoint.
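The submission is driven by the ID lookup table, which lists a RowId, an ImageId and a FeatureName for every value to be predicted. A minimal sketch with the standard library's csv module; it assumes those column names, 1-based ImageId values, and predictions ordered like the training label columns. `build_submission`, `write_submission` and `feature_names` are hypothetical names, and the two-feature example is made up.

```python
import csv
import io


def build_submission(predictions, feature_names, lookup_rows):
    """Map per-image keypoint predictions onto the lookup table's (RowId, Location) rows.

    predictions: per-image sequences of values ordered like `feature_names`.
    lookup_rows: dicts with 'RowId', 'ImageId' (assumed 1-based) and 'FeatureName' keys.
    """
    column = {name: i for i, name in enumerate(feature_names)}
    rows = []
    for entry in lookup_rows:
        image_index = int(entry["ImageId"]) - 1  # assumption: ImageId starts at 1
        value = predictions[image_index][column[entry["FeatureName"]]]
        rows.append({"RowId": int(entry["RowId"]), "Location": float(value)})
    return rows


def write_submission(rows, file_obj):
    """Write the rows in a RowId,Location CSV layout."""
    writer = csv.DictWriter(file_obj, fieldnames=["RowId", "Location"])
    writer.writeheader()
    writer.writerows(rows)


# Tiny example: one image, two features, and a lookup table asking for both.
feature_names = ["left_eye_center_x", "left_eye_center_y"]
predictions = [[66.0, 39.0]]
lookup = [
    {"RowId": "1", "ImageId": "1", "FeatureName": "left_eye_center_x"},
    {"RowId": "2", "ImageId": "1", "FeatureName": "left_eye_center_y"},
]
buffer = io.StringIO()
write_submission(build_submission(predictions, feature_names, lookup), buffer)
```

With real files, `lookup` would come from `list(csv.DictReader(f))` over the lookup table, and the buffer would be replaced by an open output file.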

We will plot a couple of images from the test set to see how well our model performs.

The Quantization-Aware Model is also giving us pretty decent results and is able to identify the different keypoints in the images.

Ideas for Improving Accuracy

Here are some more ideas which you can try out to improve your accuracy.

  • Use KFold training instead of train_test_split
  • Experiment with different strategies to fill NaN values instead of removing them
  • Use external datasets to increase the size of the training data
  • Use larger models such as EfficientNets
  • Use 2 models - one for images having 4 keypoints and another for images having 15 keypoints
  • Use a face detector and face recognizer, for example a YOLO-style architecture, to obtain a bounding box around the face, since in many images the face is very small
  • Use a face aligner
  • Experiment with more data augmentations such as reflection, zooming, shifting, blurring, noise and elastic transformations
  • Train multiple models and ensemble them into the final model
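For the first idea, k-fold training rotates which part of the data serves as validation, so every sample is used for validation exactly once. scikit-learn's KFold provides this out of the box; the numpy sketch below (with a hypothetical `kfold_indices` helper) just shows the mechanics.

```python
import numpy as np


def kfold_indices(num_samples, k=5, seed=42):
    """Yield (train_idx, valid_idx) pairs for k-fold cross-validation over shuffled indices."""
    if k < 2:
        raise ValueError("k must be at least 2")
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_samples)
    folds = np.array_split(indices, k)  # fold sizes differ by at most one
    for i in range(k):
        valid_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, valid_idx


splits = list(kfold_indices(10, k=5))
```

Each fold would train one model; their scores can be averaged, or the models ensembled as in the last idea.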

Conclusion

We conclude that both Post-Training Quantization and Quantization-Aware Training slightly improved the model's accuracy (lower RMSE than the baseline).

However, Post-Training Quantization reduced the model size from 228 MB to 25 MB, whereas after Quantization-Aware Training the model size was 97 MB.

Thus, we can infer that for this use case, Post-Training Quantization is the best-performing method in terms of time, accuracy and size.

Feel free to implement these methods on different use cases and compare the results.

The entire code is available on GitHub in the repository facial-keypoints-detection.

If you still face any difficulties reach out to me on LinkedIn or Twitter, my messages are open :)