Class Activation Maps
[ ]:
import tensorflow as tf
from gcpds.image_segmentation.models import unet_baseline
from gcpds.image_segmentation.datasets.segmentation import OxfordIiitPet
from gcpds.image_segmentation.losses import DiceCoefficient
[ ]:
# Load the Oxford-IIIT Pet dataset and keep only the training split.
dataset = OxfordIiitPet()
train, *_ = dataset()

# Each element carries (image, mask, label, image id); only (image, mask)
# is needed for segmentation training.
train = train.map(lambda img, mask, _label, _id_img: (img, mask),
                  num_parallel_calls=tf.data.AUTOTUNE)

shape = 256, 256
train = train.map(
    lambda img, mask: (
        tf.image.resize(img, shape),
        # Masks are label maps: nearest-neighbour keeps their values exact,
        # whereas the default bilinear resize blends labels along object
        # boundaries. Nearest resize preserves the input dtype, so cast to
        # float32 to keep the same dtype the bilinear version produced.
        tf.cast(tf.image.resize(mask, shape, method='nearest'), tf.float32),
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)
# Prefetch overlaps input preparation with training; batch contents are unchanged.
train = train.batch(32).prefetch(tf.data.AUTOTUNE)
[5]:
# U-Net segmentation model for 256x256 RGB inputs, with a 3-channel output head.
model = unet_baseline(
    input_shape=(256, 256, 3),
    out_channels=3,
)

# DiceCoefficient is used directly as the loss. The logged loss values are
# negative and decrease during training, so it appears to return a negated
# Dice score (lower is better) -- NOTE(review): confirm in gcpds.
model.compile(
    loss=DiceCoefficient(),
    optimizer=tf.keras.optimizers.Adam(),
)

model.fit(train, epochs=100)
Epoch 1/100
81/81 [==============================] - 35s 245ms/step - loss: -0.4946
Epoch 2/100
81/81 [==============================] - 19s 230ms/step - loss: -0.5851
Epoch 3/100
81/81 [==============================] - 19s 233ms/step - loss: -0.6406
Epoch 4/100
81/81 [==============================] - 19s 235ms/step - loss: -0.6810
Epoch 5/100
81/81 [==============================] - 19s 231ms/step - loss: -0.7145
Epoch 6/100
81/81 [==============================] - 19s 230ms/step - loss: -0.7347
Epoch 7/100
81/81 [==============================] - 19s 228ms/step - loss: -0.7506
Epoch 8/100
81/81 [==============================] - 19s 226ms/step - loss: -0.7613
Epoch 9/100
81/81 [==============================] - 18s 221ms/step - loss: -0.7724
Epoch 10/100
81/81 [==============================] - 18s 224ms/step - loss: -0.7828
Epoch 11/100
81/81 [==============================] - 19s 228ms/step - loss: -0.7886
Epoch 12/100
81/81 [==============================] - 27s 331ms/step - loss: -0.7978
Epoch 13/100
81/81 [==============================] - 19s 226ms/step - loss: -0.8053
Epoch 14/100
81/81 [==============================] - 19s 227ms/step - loss: -0.8095
Epoch 15/100
81/81 [==============================] - 18s 224ms/step - loss: -0.8146
Epoch 16/100
81/81 [==============================] - 19s 231ms/step - loss: -0.8168
Epoch 17/100
81/81 [==============================] - 19s 227ms/step - loss: -0.8173
Epoch 18/100
81/81 [==============================] - 19s 228ms/step - loss: -0.8239
Epoch 19/100
81/81 [==============================] - 19s 228ms/step - loss: -0.8327
Epoch 20/100
81/81 [==============================] - 18s 221ms/step - loss: -0.8374
Epoch 21/100
81/81 [==============================] - 18s 223ms/step - loss: -0.8430
Epoch 22/100
81/81 [==============================] - 19s 236ms/step - loss: -0.8465
Epoch 23/100
81/81 [==============================] - 18s 225ms/step - loss: -0.8503
Epoch 24/100
81/81 [==============================] - 19s 228ms/step - loss: -0.8514
Epoch 25/100
81/81 [==============================] - 18s 223ms/step - loss: -0.8533
Epoch 26/100
81/81 [==============================] - 18s 221ms/step - loss: -0.8554
Epoch 27/100
81/81 [==============================] - 18s 222ms/step - loss: -0.8583
Epoch 28/100
81/81 [==============================] - 18s 224ms/step - loss: -0.8613
Epoch 29/100
81/81 [==============================] - 19s 228ms/step - loss: -0.8639
Epoch 30/100
81/81 [==============================] - 18s 226ms/step - loss: -0.8651
Epoch 31/100
81/81 [==============================] - 18s 225ms/step - loss: -0.8669
Epoch 32/100
81/81 [==============================] - 19s 236ms/step - loss: -0.8691
Epoch 33/100
81/81 [==============================] - 19s 228ms/step - loss: -0.8695
Epoch 34/100
81/81 [==============================] - 18s 222ms/step - loss: -0.8729
Epoch 35/100
81/81 [==============================] - 18s 221ms/step - loss: -0.8757
Epoch 36/100
81/81 [==============================] - 19s 226ms/step - loss: -0.8768
Epoch 37/100
81/81 [==============================] - 19s 226ms/step - loss: -0.8770
Epoch 38/100
81/81 [==============================] - 19s 227ms/step - loss: -0.8777
Epoch 39/100
81/81 [==============================] - 18s 223ms/step - loss: -0.8791
Epoch 40/100
81/81 [==============================] - 18s 224ms/step - loss: -0.8812
Epoch 41/100
81/81 [==============================] - 18s 226ms/step - loss: -0.8818
Epoch 42/100
81/81 [==============================] - 19s 232ms/step - loss: -0.8835
Epoch 43/100
81/81 [==============================] - 18s 226ms/step - loss: -0.8851
Epoch 44/100
81/81 [==============================] - 18s 225ms/step - loss: -0.8843
Epoch 45/100
81/81 [==============================] - 18s 223ms/step - loss: -0.8862
Epoch 46/100
81/81 [==============================] - 18s 222ms/step - loss: -0.8882
Epoch 47/100
81/81 [==============================] - 18s 225ms/step - loss: -0.8910
Epoch 48/100
81/81 [==============================] - 19s 231ms/step - loss: -0.8930
Epoch 49/100
81/81 [==============================] - 19s 232ms/step - loss: -0.8952
Epoch 50/100
81/81 [==============================] - 18s 224ms/step - loss: -0.8972
Epoch 51/100
81/81 [==============================] - 19s 234ms/step - loss: -0.8992
Epoch 52/100
81/81 [==============================] - 18s 221ms/step - loss: -0.9007
Epoch 53/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9011
Epoch 54/100
81/81 [==============================] - 19s 226ms/step - loss: -0.9025
Epoch 55/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9040
Epoch 56/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9044
Epoch 57/100
81/81 [==============================] - 18s 221ms/step - loss: -0.9045
Epoch 58/100
81/81 [==============================] - 19s 229ms/step - loss: -0.9048
Epoch 59/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9057
Epoch 60/100
81/81 [==============================] - 18s 224ms/step - loss: -0.9071
Epoch 61/100
81/81 [==============================] - 19s 238ms/step - loss: -0.9068
Epoch 62/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9078
Epoch 63/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9069
Epoch 64/100
81/81 [==============================] - 18s 221ms/step - loss: -0.9063
Epoch 65/100
81/81 [==============================] - 19s 227ms/step - loss: -0.9063
Epoch 66/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9067
Epoch 67/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9080
Epoch 68/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9078
Epoch 69/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9089
Epoch 70/100
81/81 [==============================] - 19s 232ms/step - loss: -0.9113
Epoch 71/100
81/81 [==============================] - 19s 226ms/step - loss: -0.9117
Epoch 72/100
81/81 [==============================] - 18s 222ms/step - loss: -0.9119
Epoch 73/100
81/81 [==============================] - 18s 226ms/step - loss: -0.9121
Epoch 74/100
81/81 [==============================] - 19s 229ms/step - loss: -0.9111
Epoch 75/100
81/81 [==============================] - 18s 222ms/step - loss: -0.9118
Epoch 76/100
81/81 [==============================] - 18s 226ms/step - loss: -0.9122
Epoch 77/100
81/81 [==============================] - 18s 220ms/step - loss: -0.9132
Epoch 78/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9126
Epoch 79/100
81/81 [==============================] - 18s 222ms/step - loss: -0.9131
Epoch 80/100
81/81 [==============================] - 19s 236ms/step - loss: -0.9136
Epoch 81/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9141
Epoch 82/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9144
Epoch 83/100
81/81 [==============================] - 18s 226ms/step - loss: -0.9152
Epoch 84/100
81/81 [==============================] - 18s 223ms/step - loss: -0.9146
Epoch 85/100
81/81 [==============================] - 18s 224ms/step - loss: -0.9168
Epoch 86/100
81/81 [==============================] - 19s 227ms/step - loss: -0.9180
Epoch 87/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9191
Epoch 88/100
81/81 [==============================] - 18s 224ms/step - loss: -0.9192
Epoch 89/100
81/81 [==============================] - 18s 224ms/step - loss: -0.9190
Epoch 90/100
81/81 [==============================] - 19s 235ms/step - loss: -0.9189
Epoch 91/100
81/81 [==============================] - 19s 227ms/step - loss: -0.9191
Epoch 92/100
81/81 [==============================] - 18s 221ms/step - loss: -0.9184
Epoch 93/100
81/81 [==============================] - 18s 225ms/step - loss: -0.9170
Epoch 94/100
81/81 [==============================] - 18s 222ms/step - loss: -0.9172
Epoch 95/100
81/81 [==============================] - 19s 227ms/step - loss: -0.9177
Epoch 96/100
81/81 [==============================] - 19s 227ms/step - loss: -0.9179
Epoch 97/100
81/81 [==============================] - 18s 222ms/step - loss: -0.9184
Epoch 98/100
81/81 [==============================] - 18s 222ms/step - loss: -0.9193
Epoch 99/100
81/81 [==============================] - 19s 235ms/step - loss: -0.9216
Epoch 100/100
81/81 [==============================] - 18s 221ms/step - loss: -0.9229
[5]:
<keras.callbacks.History at 0x7f9f31093700>
[8]:
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.utils.model_modifiers import ReplaceToLinear
from gcpds.image_segmentation.class_activation_maps import SegScore
[9]:
def plot_cams(cam, data, nrows=1, ncols=5, figsize=(20, 10)):
    """Overlay class-activation maps on their input images in a grid.

    Parameters
    ----------
    cam : array-like
        Per-image activation maps aligned with ``data``. ``cam[i]`` is fed to
        ``cm.jet``, which expects values in [0, 1]
        (NOTE(review): assumes the CAM producer normalizes -- confirm).
    data : array-like
        Images to draw underneath the heatmaps; ``data[i]`` goes to ``imshow``.
    nrows, ncols : int
        Grid layout. At most ``nrows * ncols`` images are drawn. If fewer
        images are supplied, the remaining panels are left blank instead of
        raising IndexError.
    figsize : tuple
        Matplotlib figure size in inches.
    """
    # squeeze=False always yields a 2-D array of Axes, so a 1x1 grid also works.
    _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    axes = axes.ravel()
    n_shown = min(len(cam), len(data), axes.size)
    for i, ax in enumerate(axes):
        ax.axis('off')
        if i >= n_shown:
            continue
        # cm.jet maps scalars to RGBA; drop alpha and convert to 8-bit RGB.
        heatmap = np.uint8(cm.jet(cam[i])[..., :3] * 255)
        ax.imshow(data[i])
        # The heatmap is already RGB, so a cmap would be ignored here.
        ax.imshow(heatmap, alpha=0.3)
    plt.tight_layout()
    plt.show()
[10]:
# Pull out a single batch of 5 (image, mask) pairs to visualise.
# 5 matches the default ncols of plot_cams.
images, masks = (
    train
    .unbatch()   # back to individual (image, mask) examples
    .batch(5)    # regroup them into batches of 5
    .take(1)     # keep only the first such batch
    .get_single_element()
)
[17]:
# Grad-CAM explainer over a clone of the trained model (clone=True leaves
# `model` itself untouched). Per the tf-keras-vis docs, ReplaceToLinear sets
# the last layer's activation to linear before gradients are computed.
gradcam = Gradcam(model, model_modifier=ReplaceToLinear(), clone=True)

# Class 0. SegScore is a project helper that presumably scores the
# segmentation output against `masks` for the given class -- confirm in gcpds.
score = SegScore(masks, target_class=0)

# The CAM is taken at the explicitly named layer 'Conv91';
# seek_penultimate_conv_layer=False turns off tf-keras-vis' automatic search.
cam = gradcam(
    score,
    images,
    penultimate_layer='Conv91',
    seek_penultimate_conv_layer=False,
)
plot_cams(cam, images)
[18]:
# Class 1: same Grad-CAM explainer and layer as above, different target class.
score = SegScore(masks, target_class=1)
cam = gradcam(score, images,
              penultimate_layer='Conv91', seek_penultimate_conv_layer=False)
plot_cams(cam, images)
[19]:
# Class 2: same Grad-CAM explainer and layer as above, different target class.
score = SegScore(masks, target_class=2)
cam = gradcam(score, images,
              penultimate_layer='Conv91', seek_penultimate_conv_layer=False)
plot_cams(cam, images)