Introducing tf-explain, Interpretability for TensorFlow 2.0

Written by Raphaël Meudec

A TensorFlow 2.0 library for deep learning model interpretability.

Grad CAM Method (Original Photo by Kelly Lund on Unsplash)

Understanding deep networks is crucial for AI adoption. tf-explain offers interpretability methods to gain insight into your network.

The library is designed for the TensorFlow 2.0 workflow and uses the tf.keras API wherever possible. It provides:

  • Heatmap visualizations and gradient analysis
  • Usage both outside of training and as a tf.keras.callbacks.Callback
  • TensorBoard integration

Disclaimer: This library is not an official Google product, although it is built for TensorFlow 2.0 models.

Why interpretability

The main challenge when working with deep neural networks is understanding the behavior of trained networks. This matters both to scientists, who need to debug and improve the current model, and to users, who need to trust the method. It is difficult for a human to get feedback from a neural network, and interpretability research has emerged to help tackle this:

  • Analysis of decisions over a validation set helps identify issues with the network (for instance, bias in the dataset or mislabeled data)
  • Heatmap visualizations are often appreciated by non-expert users

However, these methods are not yet well integrated into deep learning workflows.

The solution: off-the-shelf analysis tools for your tf.keras models

tf-explain implements interpretability methods as TensorFlow 2.0 callbacks to make neural networks easier to understand.

The library was built to offer a comprehensive list of interpretability methods that are directly usable in your TensorFlow workflow:

  • TensorFlow 2.0 compatibility
  • Unified interface across methods
  • Support for training integration (callbacks, TensorBoard)

Built for TensorFlow 2.0

tf-explain follows the new TF 2.0 API and is based on tf.keras wherever possible. It relies on the @tf.function decorator, which helps support both eager and graph modes. This keeps the computation time of most algorithms negligible compared to a full training run.
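To make the eager/graph point concrete, here is a minimal sketch of a gradient computation that runs unchanged in both modes. It is not code from tf-explain: the function name input_gradients and the toy model are assumptions made for this illustration.

```python
import tensorflow as tf


def input_gradients(model, images, class_index):
    """Gradient of one class score with respect to the input pixels."""
    with tf.GradientTape() as tape:
        # Inputs are plain tensors, not variables, so the tape must watch them.
        tape.watch(images)
        class_scores = model(images)[:, class_index]
    return tape.gradient(class_scores, images)


# Wrapping the same Python function with tf.function gives a graph-mode version.
graph_input_gradients = tf.function(input_gradients)

# Toy stand-in for a real network (hypothetical, for illustration only).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4, 4, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3),
])
images = tf.random.uniform((2, 4, 4, 1))

eager_result = input_gradients(model, images, class_index=0)
graph_result = graph_input_gradients(model, images, class_index=0)
```

Both calls return a tensor with the same shape as the input batch, and the two results should agree up to floating-point error. Only the wrapped version is traced into a graph.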

Methods from research

The algorithms implemented in tf-explain come directly from research. As of today, the implemented methods are Activations Visualization, Occlusion Sensitivity, Grad CAM, and SmoothGrad.

From Left to Right: Input Image, Activations Visualizations, Occlusion Sensitivity, Grad CAM, SmoothGrad on VGG16

More methods will be integrated into the library in the coming weeks.

A single entry point

tf-explain offers a single entry point for all its algorithms. Each method implements .explain(validation_data, model, *args), with additional arguments depending on its needs. For a full description of each method's arguments, refer to the documentation.
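The shared-entry-point idea can be sketched in plain NumPy. The two classes below are hypothetical stand-ins, not the tf-explain implementations: both expose .explain(validation_data, model, ...) and differ only in their extra arguments (patch_size and epsilon are assumed names). The toy model is a simple callable used in place of a tf.keras model.

```python
import numpy as np


class OcclusionSensitivitySketch:
    """Hypothetical explainer: hides patches and measures the score drop."""

    def explain(self, validation_data, model, class_index, patch_size=2):
        images, _ = validation_data  # labels are not needed by this method
        image = images[0]
        height, width = image.shape
        baseline = model(image[np.newaxis])[0, class_index]
        heatmap = np.zeros((height, width), dtype=np.float32)
        for top in range(0, height, patch_size):
            for left in range(0, width, patch_size):
                occluded = image.copy()
                occluded[top:top + patch_size, left:left + patch_size] = 0.0
                score = model(occluded[np.newaxis])[0, class_index]
                # A large drop means the hidden patch mattered for this class.
                heatmap[top:top + patch_size, left:left + patch_size] = baseline - score
        return heatmap


class FiniteDifferenceSketch:
    """Hypothetical explainer: approximates input gradients pixel by pixel."""

    def explain(self, validation_data, model, class_index, epsilon=1e-3):
        images, _ = validation_data
        image = images[0].astype(np.float64)
        baseline = model(image[np.newaxis])[0, class_index]
        heatmap = np.zeros(image.shape, dtype=np.float64)
        for row in range(image.shape[0]):
            for col in range(image.shape[1]):
                nudged = image.copy()
                nudged[row, col] += epsilon
                score = model(nudged[np.newaxis])[0, class_index]
                heatmap[row, col] = (score - baseline) / epsilon
        return heatmap


def toy_model(batch):
    """Stand-in for a network: class 0 only looks at the top-left 2x2 corner."""
    corner = batch[:, :2, :2].sum(axis=(1, 2))
    rest = batch.sum(axis=(1, 2)) - corner
    return np.stack([corner, rest], axis=1)


images = np.ones((1, 4, 4), dtype=np.float32)
labels = np.array([0])

# Same call shape for both methods; only the method-specific arguments change.
occlusion_map = OcclusionSensitivitySketch().explain(
    (images, labels), toy_model, class_index=0, patch_size=2
)
gradient_map = FiniteDifferenceSketch().explain(
    (images, labels), toy_model, class_index=0, epsilon=1e-3
)
```

Because the toy model's class 0 score depends only on the top-left corner, both heatmaps are non-zero only in that corner. Keeping the call shape identical means the explainer can be swapped without touching the surrounding code.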

Close to your TensorFlow workflow

TensorBoard Integration

A key objective for tf-explain is to support monitoring during training. Checking whether a method's output looks reasonable during the first 3 epochs can save you from an hours-long failed training run. Therefore, each implemented algorithm has a corresponding tf.keras.callbacks.Callback and offers a TensorBoard integration.

Using TensorBoard was an obvious choice, since it lets users gather all metrics and information in a single dashboard. Saving outputs to disk is also available.
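As a rough illustration of this pattern, the sketch below logs a heatmap to TensorBoard at the end of every epoch. It is not one of the tf-explain callbacks: the class InputGradientCallback, its arguments, and the toy model are assumptions made for this example, and it uses plain input gradients rather than any of the library's methods.

```python
import tempfile

import tensorflow as tf


class InputGradientCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: writes input-gradient heatmaps after each epoch."""

    def __init__(self, validation_data, class_index, log_dir):
        super().__init__()
        self.images, _ = validation_data
        self.class_index = class_index
        self.writer = tf.summary.create_file_writer(log_dir)
        self.logged_epochs = []

    def on_epoch_end(self, epoch, logs=None):
        images = tf.convert_to_tensor(self.images)
        with tf.GradientTape() as tape:
            tape.watch(images)
            scores = self.model(images)[:, self.class_index]
        heatmaps = tf.abs(tape.gradient(scores, images))
        # Scale to [0, 1] so TensorBoard renders the values as an image.
        heatmaps = heatmaps / (tf.reduce_max(heatmaps) + 1e-8)
        with self.writer.as_default():
            tf.summary.image("input_gradients", heatmaps, step=epoch)
        self.writer.flush()
        self.logged_epochs.append(epoch)


# Toy data and model (hypothetical), just large enough to run two epochs.
images = tf.random.uniform((8, 8, 8, 1))
labels = tf.random.uniform((8,), maxval=2, dtype=tf.int32)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2),
])
model.compile(
    optimizer="sgd",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

log_dir = tempfile.mkdtemp()
callback = InputGradientCallback((images, labels), class_index=0, log_dir=log_dir)
model.fit(images, labels, epochs=2, batch_size=4, verbose=0, callbacks=[callback])
```

Pointing TensorBoard at log_dir then shows the heatmaps next to the usual training metrics, one step per epoch.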

tf-explain is live!

An alpha release of tf-explain has been available on PyPI since last week. Install it with pip install tf-explain.

We would also love to hear about how you are using tf-explain in your projects. Drop us a line at tf-explain@sicara.com or ping me on Twitter.

Thanks to Flavian Hautbois, Antoine Toubhans, and Jeremy Joslove.
