Measures for Interpretability of Class Activation Maps

[ ]:
!pip install -U git+https://github.com/UN-GCPDS/python-gcpds.image_segmentation.git
!pip install -U git+https://github.com/UN-GCPDS/tf-keras-vis.git
[2]:
import numpy as np
import tensorflow as tf

from gcpds.image_segmentation.models import unet_baseline
from gcpds.image_segmentation.datasets.segmentation import InfraredThermalFeet
from gcpds.image_segmentation.losses import DiceCoefficient
[3]:
# Load the first partition of InfraredThermalFeet and build the training pipeline.
dataset = InfraredThermalFeet()
train, *_ = dataset()

shape = 256, 256


def _resize_pair(img, mask, id_img):
    """Drop the sample id and resize image and mask to `shape`.

    The image is resized bilinearly. The mask is resized with nearest-neighbour
    interpolation so it keeps its original label values: bilinear resizing would
    blend neighbouring labels into fractional values, which breaks the exact
    `mask == target_class` comparisons used later to build the ROIs. The mask is
    cast to float32 so its dtype matches what bilinear resizing produced.
    """
    img = tf.image.resize(img, shape)
    mask = tf.cast(tf.image.resize(mask, shape, method='nearest'), tf.float32)
    return img, mask


train = train.map(_resize_pair, num_parallel_calls=tf.data.AUTOTUNE)
train = train.batch(32)
Downloading...
From: https://drive.google.com/uc?id=1HZa4pVwlIXCrRIidflB158kmtYGW23Qe&confirm=t
To: /usr/local/lib/python3.10/dist-packages/gcpds/image_segmentation/datasets/segmentation/Datasets/InfraredThermalFeet/InfraredThermalFeet.zip
100%|██████████| 10.6M/10.6M [00:00<00:00, 29.6MB/s]
 Number of images for Partition 1: 111
 Number of images for Partition 2: 9
 Number of images for Partition 3: 46
[4]:
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation (and therefore
# the trained model and every CAM measure below) is repeatable between runs.
tf.keras.utils.set_random_seed(42)

# Single-channel input at the pipeline's resize target; 3 output channels.
model = unet_baseline(input_shape=(*shape, 1), out_channels=3)

model.compile(loss=DiceCoefficient(), optimizer=tf.keras.optimizers.Adam())
# Keep the History object instead of letting its repr print as cell output.
history = model.fit(train, epochs=100)
Epoch 1/100
4/4 [==============================] - 37s 922ms/step - loss: -0.3639
Epoch 2/100
4/4 [==============================] - 1s 249ms/step - loss: -0.4256
Epoch 3/100
4/4 [==============================] - 1s 250ms/step - loss: -0.4558
Epoch 4/100
4/4 [==============================] - 1s 250ms/step - loss: -0.4803
Epoch 5/100
4/4 [==============================] - 1s 253ms/step - loss: -0.5014
Epoch 6/100
4/4 [==============================] - 1s 234ms/step - loss: -0.5198
Epoch 7/100
4/4 [==============================] - 1s 251ms/step - loss: -0.5345
Epoch 8/100
4/4 [==============================] - 2s 383ms/step - loss: -0.5525
Epoch 9/100
4/4 [==============================] - 2s 407ms/step - loss: -0.5659
Epoch 10/100
4/4 [==============================] - 2s 261ms/step - loss: -0.5776
Epoch 11/100
4/4 [==============================] - 1s 280ms/step - loss: -0.5887
Epoch 12/100
4/4 [==============================] - 1s 253ms/step - loss: -0.5981
Epoch 13/100
4/4 [==============================] - 1s 252ms/step - loss: -0.6084
Epoch 14/100
4/4 [==============================] - 1s 260ms/step - loss: -0.6149
Epoch 15/100
4/4 [==============================] - 1s 262ms/step - loss: -0.6263
Epoch 16/100
4/4 [==============================] - 1s 260ms/step - loss: -0.6350
Epoch 17/100
4/4 [==============================] - 1s 251ms/step - loss: -0.6461
Epoch 18/100
4/4 [==============================] - 1s 260ms/step - loss: -0.6567
Epoch 19/100
4/4 [==============================] - 2s 385ms/step - loss: -0.6644
Epoch 20/100
4/4 [==============================] - 2s 428ms/step - loss: -0.6732
Epoch 21/100
4/4 [==============================] - 1s 325ms/step - loss: -0.6824
Epoch 22/100
4/4 [==============================] - 2s 343ms/step - loss: -0.6910
Epoch 23/100
4/4 [==============================] - 1s 254ms/step - loss: -0.6994
Epoch 24/100
4/4 [==============================] - 1s 307ms/step - loss: -0.7077
Epoch 25/100
4/4 [==============================] - 3s 634ms/step - loss: -0.7151
Epoch 26/100
4/4 [==============================] - 2s 328ms/step - loss: -0.7235
Epoch 27/100
4/4 [==============================] - 2s 347ms/step - loss: -0.7312
Epoch 28/100
4/4 [==============================] - 2s 346ms/step - loss: -0.7390
Epoch 29/100
4/4 [==============================] - 1s 256ms/step - loss: -0.7468
Epoch 30/100
4/4 [==============================] - 1s 245ms/step - loss: -0.7537
Epoch 31/100
4/4 [==============================] - 1s 266ms/step - loss: -0.7609
Epoch 32/100
4/4 [==============================] - 2s 498ms/step - loss: -0.7683
Epoch 33/100
4/4 [==============================] - 2s 362ms/step - loss: -0.7738
Epoch 34/100
4/4 [==============================] - 1s 258ms/step - loss: -0.7801
Epoch 35/100
4/4 [==============================] - 1s 337ms/step - loss: -0.7877
Epoch 36/100
4/4 [==============================] - 1s 247ms/step - loss: -0.7949
Epoch 37/100
4/4 [==============================] - 1s 245ms/step - loss: -0.8018
Epoch 38/100
4/4 [==============================] - 2s 299ms/step - loss: -0.8080
Epoch 39/100
4/4 [==============================] - 2s 363ms/step - loss: -0.8145
Epoch 40/100
4/4 [==============================] - 2s 355ms/step - loss: -0.8194
Epoch 41/100
4/4 [==============================] - 1s 245ms/step - loss: -0.8248
Epoch 42/100
4/4 [==============================] - 1s 252ms/step - loss: -0.8309
Epoch 43/100
4/4 [==============================] - 1s 271ms/step - loss: -0.8361
Epoch 44/100
4/4 [==============================] - 1s 242ms/step - loss: -0.8425
Epoch 45/100
4/4 [==============================] - 1s 247ms/step - loss: -0.8475
Epoch 46/100
4/4 [==============================] - 1s 252ms/step - loss: -0.8531
Epoch 47/100
4/4 [==============================] - 1s 303ms/step - loss: -0.8574
Epoch 48/100
4/4 [==============================] - 1s 253ms/step - loss: -0.8625
Epoch 49/100
4/4 [==============================] - 1s 248ms/step - loss: -0.8672
Epoch 50/100
4/4 [==============================] - 1s 247ms/step - loss: -0.8707
Epoch 51/100
4/4 [==============================] - 1s 249ms/step - loss: -0.8760
Epoch 52/100
4/4 [==============================] - 1s 241ms/step - loss: -0.8799
Epoch 53/100
4/4 [==============================] - 2s 317ms/step - loss: -0.8842
Epoch 54/100
4/4 [==============================] - 1s 309ms/step - loss: -0.8878
Epoch 55/100
4/4 [==============================] - 2s 558ms/step - loss: -0.8911
Epoch 56/100
4/4 [==============================] - 2s 351ms/step - loss: -0.8950
Epoch 57/100
4/4 [==============================] - 1s 254ms/step - loss: -0.8981
Epoch 58/100
4/4 [==============================] - 1s 243ms/step - loss: -0.9008
Epoch 59/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9032
Epoch 60/100
4/4 [==============================] - 1s 245ms/step - loss: -0.9065
Epoch 61/100
4/4 [==============================] - 1s 251ms/step - loss: -0.9093
Epoch 62/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9124
Epoch 63/100
4/4 [==============================] - 1s 304ms/step - loss: -0.9142
Epoch 64/100
4/4 [==============================] - 2s 355ms/step - loss: -0.9181
Epoch 65/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9194
Epoch 66/100
4/4 [==============================] - 1s 246ms/step - loss: -0.9221
Epoch 67/100
4/4 [==============================] - 1s 251ms/step - loss: -0.9247
Epoch 68/100
4/4 [==============================] - 1s 245ms/step - loss: -0.9268
Epoch 69/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9281
Epoch 70/100
4/4 [==============================] - 1s 245ms/step - loss: -0.9307
Epoch 71/100
4/4 [==============================] - 1s 256ms/step - loss: -0.9332
Epoch 72/100
4/4 [==============================] - 1s 340ms/step - loss: -0.9343
Epoch 73/100
4/4 [==============================] - 2s 322ms/step - loss: -0.9361
Epoch 74/100
4/4 [==============================] - 1s 249ms/step - loss: -0.9380
Epoch 75/100
4/4 [==============================] - 1s 248ms/step - loss: -0.9402
Epoch 76/100
4/4 [==============================] - 1s 248ms/step - loss: -0.9421
Epoch 77/100
4/4 [==============================] - 1s 255ms/step - loss: -0.9438
Epoch 78/100
4/4 [==============================] - 1s 244ms/step - loss: -0.9457
Epoch 79/100
4/4 [==============================] - 1s 246ms/step - loss: -0.9467
Epoch 80/100
4/4 [==============================] - 1s 246ms/step - loss: -0.9482
Epoch 81/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9497
Epoch 82/100
4/4 [==============================] - 2s 372ms/step - loss: -0.9513
Epoch 83/100
4/4 [==============================] - 1s 268ms/step - loss: -0.9524
Epoch 84/100
4/4 [==============================] - 1s 249ms/step - loss: -0.9530
Epoch 85/100
4/4 [==============================] - 1s 245ms/step - loss: -0.9539
Epoch 86/100
4/4 [==============================] - 1s 245ms/step - loss: -0.9534
Epoch 87/100
4/4 [==============================] - 1s 243ms/step - loss: -0.9543
Epoch 88/100
4/4 [==============================] - 1s 249ms/step - loss: -0.9560
Epoch 89/100
4/4 [==============================] - 1s 246ms/step - loss: -0.9566
Epoch 90/100
4/4 [==============================] - 1s 250ms/step - loss: -0.9577
Epoch 91/100
4/4 [==============================] - 1s 329ms/step - loss: -0.9586
Epoch 92/100
4/4 [==============================] - 1s 248ms/step - loss: -0.9594
Epoch 93/100
4/4 [==============================] - 1s 246ms/step - loss: -0.9608
Epoch 94/100
4/4 [==============================] - 1s 243ms/step - loss: -0.9616
Epoch 95/100
4/4 [==============================] - 1s 243ms/step - loss: -0.9622
Epoch 96/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9631
Epoch 97/100
4/4 [==============================] - 1s 245ms/step - loss: -0.9638
Epoch 98/100
4/4 [==============================] - 1s 247ms/step - loss: -0.9647
Epoch 99/100
4/4 [==============================] - 1s 252ms/step - loss: -0.9651
Epoch 100/100
4/4 [==============================] - 2s 365ms/step - loss: -0.9653
[4]:
<keras.callbacks.History at 0x7f66fc108f70>
[11]:
# Positions in `model.layers` whose layer name contains 'Conv'; each one is
# later passed to Grad-CAM as `penultimate_layer`.
layer_indexes = [
    position
    for position, layer in enumerate(model.layers)
    if 'Conv' in layer.name
]
layer_indexes
[11]:
[2, 4, 7, 9, 12, 14, 17, 19, 22, 24, 28, 30, 34, 36, 40, 42, 46, 48, 50]
[8]:
from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
import matplotlib.pyplot as plt

from gcpds.image_segmentation.class_activation_maps import SegScore
[9]:
# Grad-CAM explainer for the trained model. `clone=True` asks tf-keras-vis to
# work on a copy, and ReplaceToLinear presumably linearises the output
# activation before gradients are taken (see tf-keras-vis docs to confirm).
gradcam = Gradcam(
    model,
    model_modifier=ReplaceToLinear(),
    clone=True,
)

CAM Dice

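\[D^{'}_r = \mathbb{E}_{l}\left\{\mathbb{E}_{n}\Biggl\{ 2 \frac{\mathbf{1}^\top \left(\tilde{\mathbf{M}}^r_n \odot \mathbf{S}^r_{nl}\right) \mathbf{1}}{ \mathbf{1}^\top \tilde{\mathbf{M}}^r_n \mathbf{1} + \mathbf{1}^\top \mathbf{S}^r_{nl} \mathbf{1}} :\forall n\in N\Biggr\}:\forall l\in L\right\}, \quad D_{r}^{'} \in [0,1]\]
where \(\tilde{\mathbf{M}}^r_n\) is the binary ground-truth mask of class \(r\) for sample \(n\), \(\mathbf{S}^r_{nl}\) is the Grad-CAM map for class \(r\) computed at convolutional layer \(l\), and \(\odot\) denotes the element-wise product.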
[22]:
from gcpds.image_segmentation.class_activation_maps.measures.partials import cam_dice
[39]:
target_class = 1  # CAMs and ROIs are computed for class 1 only

# One image per batch so each Grad-CAM call explains a single sample.
# Built once here instead of once per layer inside the loop.
samples = train.unbatch().batch(1)

# Size the result from the data rather than hard-coding the sample count;
# a mismatch would raise IndexError or leave silent zero columns.
total_layers = len(layer_indexes)
total_samples = sum(1 for _ in samples)

# Rows: conv layers (in `layer_indexes` order); columns: samples.
cam_dice_values = np.zeros(shape=(total_layers, total_samples))

for i, layer_index in enumerate(layer_indexes):
    for j, (img, mask) in enumerate(samples):
        score = SegScore(mask, target_class=target_class)
        cam = gradcam(score,
                      img,
                      penultimate_layer=layer_index,
                      seek_penultimate_conv_layer=False)
        cam = cam[..., None]  # add a trailing axis, presumably to match the mask's channel axis
        roi = tf.cast(mask == target_class, tf.float32)
        cam_dice_values[i, j] = cam_dice(roi, cam)
[40]:
# Heat map of CAM Dice: one row per conv layer, one column per sample.
fig, ax = plt.subplots()
image = ax.imshow(cam_dice_values, aspect='auto')  # 'auto' keeps a wide layers x samples matrix readable
ax.set(title='CAM Dice per layer and sample (class 1)',
       xlabel='Sample',
       ylabel='Layer (position in layer_indexes)')
fig.colorbar(image, ax=ax);
[40]:
<matplotlib.image.AxesImage at 0x7f667ed5b6a0>
../_images/notebooks_07-interpretability_measures_11_1.png
[42]:
# D'_r: mean over samples for every layer, then mean over layers.
per_layer_cam_dice = cam_dice_values.mean(axis=1)
cam_dice_value = per_layer_cam_dice.mean()
cam_dice_value
[42]:
0.26159253942596367

CAM-based Cumulative Relevance

\[\rho_r = \mathbb{E}_{l}\left\{\mathbb{E}_{n}\Biggl\{ \frac{ \mathbf{1}^\top (\tilde{\mathbf{M}}^r_n \odot \mathbf{S}^r_{nl}) \mathbf{1}}{\mathbf{1}^\top \mathbf{S}^r_{nl} \mathbf{1}} : \forall n\in N\Biggr\}:\forall l\in L\right\} , \quad \rho_r\in [0,1]\]
[47]:
from gcpds.image_segmentation.class_activation_maps.measures.partials import cam_cumulative_relevance
[48]:
target_class = 1  # CAMs and ROIs are computed for class 1 only

# One image per batch so each Grad-CAM call explains a single sample.
# Built once here instead of once per layer inside the loop.
samples = train.unbatch().batch(1)

# Size the result from the data rather than hard-coding the sample count.
total_layers = len(layer_indexes)
total_samples = sum(1 for _ in samples)

# Rows: conv layers (in `layer_indexes` order); columns: samples.
cam_cumulative_values = np.zeros(shape=(total_layers, total_samples))

for i, layer_index in enumerate(layer_indexes):
    for j, (img, mask) in enumerate(samples):
        score = SegScore(mask, target_class=target_class)
        cam = gradcam(score,
                      img,
                      penultimate_layer=layer_index,
                      seek_penultimate_conv_layer=False)
        cam = cam[..., None]  # add a trailing axis, presumably to match the mask's channel axis
        roi = tf.cast(mask == target_class, tf.float32)
        cam_cumulative_values[i, j] = cam_cumulative_relevance(roi, cam)
[49]:
# Heat map of CAM-based cumulative relevance: one row per conv layer, one column per sample.
fig, ax = plt.subplots()
image = ax.imshow(cam_cumulative_values, aspect='auto')
ax.set(title='CAM-based cumulative relevance per layer and sample (class 1)',
       xlabel='Sample',
       ylabel='Layer (position in layer_indexes)')
fig.colorbar(image, ax=ax);
[49]:
<matplotlib.image.AxesImage at 0x7f667e2bd4e0>
../_images/notebooks_07-interpretability_measures_16_1.png
[50]:
# rho_r: mean over samples for every layer, then mean over layers.
# The scalar gets its own name so the layers x samples matrix
# `cam_cumulative_values` stays intact and this cell can be re-run safely.
cam_cumulative_value = cam_cumulative_values.mean(axis=1).mean()
cam_cumulative_value
[50]:
0.5578747539461217

Mask-based Cumulative Relevance

\[\varrho^{'}_{rl} = \mathbb{E}_{n}\Biggl\{ \frac{ \mathbf{1}^\top (\tilde{\mathbf{M}}^r_n \odot \mathbf{S}^r_{nl}) \mathbf{1}}{\mathbf{1}^\top \tilde{\mathbf{M}}^r_n \mathbf{1}} : \forall n\in N\Biggr\}, \quad \varrho^{'}_{rl}\in\mathbb{R}^+\]
\[\varrho_r = \mathbb{E}_{l}\left\{ \frac{\varrho^{'}_{rl}}{\max\limits_{c \in \{0,1\}} \varrho^{'}_{cl}} :\forall l\in L\right\}, \quad \varrho_r\in [0,1]\]
[51]:
from gcpds.image_segmentation.class_activation_maps.measures.partials import masked_cumulative_relevance
[53]:
target_class = 1  # mask-based relevance for class 1

# One image per batch so each Grad-CAM call explains a single sample.
# Built once here instead of once per layer inside the loop.
samples = train.unbatch().batch(1)

# Size the result from the data rather than hard-coding the sample count.
total_layers = len(layer_indexes)
total_samples = sum(1 for _ in samples)

# Rows: conv layers (in `layer_indexes` order); columns: samples.
masked_cumulative_values_class_one = np.zeros(shape=(total_layers, total_samples))

for i, layer_index in enumerate(layer_indexes):
    for j, (img, mask) in enumerate(samples):
        score = SegScore(mask, target_class=target_class)
        cam = gradcam(score,
                      img,
                      penultimate_layer=layer_index,
                      seek_penultimate_conv_layer=False)
        cam = cam[..., None]  # add a trailing axis, presumably to match the mask's channel axis
        roi = tf.cast(mask == target_class, tf.float32)
        masked_cumulative_values_class_one[i, j] = masked_cumulative_relevance(roi, cam)
[57]:
# Heat map of mask-based cumulative relevance for class 1 (layers x samples).
fig, ax = plt.subplots()
image = ax.imshow(masked_cumulative_values_class_one, aspect='auto')
ax.set(title='Mask-based cumulative relevance per layer and sample (class 1)',
       xlabel='Sample',
       ylabel='Layer (position in layer_indexes)')
fig.colorbar(image, ax=ax);
[57]:
<matplotlib.image.AxesImage at 0x7f6667ed4550>
../_images/notebooks_07-interpretability_measures_21_1.png
[56]:
# Class 0 is needed for the max over c in {0, 1} that normalises the
# per-layer relevance.
target_class = 0

# Recomputed here so this cell does not depend on variables left over
# from the class-1 cell.
samples = train.unbatch().batch(1)
total_layers = len(layer_indexes)
total_samples = sum(1 for _ in samples)

# Rows: conv layers (in `layer_indexes` order); columns: samples.
masked_cumulative_values_class_zero = np.zeros(shape=(total_layers, total_samples))

for i, layer_index in enumerate(layer_indexes):
    for j, (img, mask) in enumerate(samples):
        score = SegScore(mask, target_class=target_class)
        cam = gradcam(score,
                      img,
                      penultimate_layer=layer_index,
                      seek_penultimate_conv_layer=False)
        cam = cam[..., None]  # add a trailing axis, presumably to match the mask's channel axis
        roi = tf.cast(mask == target_class, tf.float32)
        masked_cumulative_values_class_zero[i, j] = masked_cumulative_relevance(roi, cam)
[58]:
# Heat map of mask-based cumulative relevance for class 0 (layers x samples).
fig, ax = plt.subplots()
image = ax.imshow(masked_cumulative_values_class_zero, aspect='auto')
ax.set(title='Mask-based cumulative relevance per layer and sample (class 0)',
       xlabel='Sample',
       ylabel='Layer (position in layer_indexes)')
fig.colorbar(image, ax=ax);
[58]:
<matplotlib.image.AxesImage at 0x7f6667dbfe20>
../_images/notebooks_07-interpretability_measures_23_1.png
[65]:
# varrho'_{rl}: average over samples per layer, kept as (layers, 1) columns.
masked_cumulative_mean_samples_class_one = np.mean(
    masked_cumulative_values_class_one, axis=1, keepdims=True)
masked_cumulative_mean_samples_class_zero = np.mean(
    masked_cumulative_values_class_zero, axis=1, keepdims=True)

# Column 0 holds class 1, column 1 holds class 0.
masked_cumulative_mean_samples = np.hstack([
    masked_cumulative_mean_samples_class_one,
    masked_cumulative_mean_samples_class_zero,
])
[66]:
# Expect one row per analysed conv layer and one column per class (1, then 0).
assert masked_cumulative_mean_samples.shape == (len(layer_indexes), 2)
masked_cumulative_mean_samples.shape
[66]:
(19, 2)
[73]:
# Normalise each layer's per-class mean by the larger of the two class means
# for that same layer (the max over c in {0, 1}).
per_layer_max = masked_cumulative_mean_samples.max(axis=1)
masked_cumulative_class_one = masked_cumulative_mean_samples[:, 0] / per_layer_max
masked_cumulative_class_zero = masked_cumulative_mean_samples[:, 1] / per_layer_max
[75]:
# varrho_r: average the normalised per-layer values over layers.
# Note: this rebinds both names from per-layer arrays to scalars.
masked_cumulative_class_one, masked_cumulative_class_zero = (
    masked_cumulative_class_one.mean(),
    masked_cumulative_class_zero.mean(),
)
[76]:
(masked_cumulative_class_one, masked_cumulative_class_zero)
[76]:
(0.8432377253855776, 0.5880907436847627)