Source code for image_segmentation.models.baseline_unet

"""
https://github.com/cralji/RFF-Nerve-UTP/blob/main/UNET-Nerve-UTP.ipynb
"""

from functools import partial

import tensorflow as tf
from tensorflow.keras import Model, layers, regularizers


# Layer factories shared by every stage of the network. Each one is a
# ``functools.partial`` so call sites only supply what varies (filters,
# initializer, layer name).

# 3x3 convolution with ReLU activation and "same" padding.
DefaultConv2D = partial(
    layers.Conv2D,
    kernel_size=3,
    activation='relu',
    padding="same",
)

# Max pooling with a pool size of 2.
DefaultPooling = partial(layers.MaxPool2D, pool_size=2)

# Upsampling layer; (2, 2) is bound as the first positional argument.
upsample = partial(layers.UpSampling2D, (2, 2))

def kernel_initializer(seed):
    """Return a Glorot-uniform kernel initializer seeded with *seed*.

    Args:
        seed: Seed forwarded to ``tf.keras.initializers.GlorotUniform``.

    Returns:
        A ``tf.keras.initializers.GlorotUniform`` instance.
    """
    return tf.keras.initializers.GlorotUniform(seed=seed)
def unet_baseline(input_shape=(128,128,3), name='UNET', out_channels=1, out_ActFunction='sigmoid'):
    """Build the baseline U-Net as a Keras functional ``Model``.

    The network has four encoder levels (two conv + batch-norm pairs followed
    by a pooling layer each), a two-convolution bottleneck, and four decoder
    levels (upsample, concatenate with the matching encoder convolution
    output, then two conv + batch-norm pairs). A final 1x1 convolution
    produces the output.

    The skip connections are taken from the second convolution of each
    encoder level *before* its batch normalization.

    Args:
        input_shape: Shape passed to ``layers.Input`` (without batch axis).
        name: Name given to the returned model.
        out_channels: Number of filters of the final 1x1 convolution.
        out_ActFunction: Activation of the final 1x1 convolution.

    Returns:
        The assembled ``Model``.

    Note:
        The input is pooled by 2 four times and upsampled by 2 four times, so
        the pooled spatial sizes presumably need to be divisible by 16 for
        the concatenations to line up -- confirm for non-default shapes.
        The size comments below refer to the default 128x128 input.
    """
    # ---- Encoder ----
    inputs = layers.Input(shape=input_shape)
    x = layers.BatchNormalization(name='Batch00')(inputs)

    # Level 1
    x = DefaultConv2D(8, kernel_initializer=kernel_initializer(34), name='Conv10')(x)
    x = layers.BatchNormalization(name='Batch10')(x)
    x = level_1 = DefaultConv2D(8, kernel_initializer=kernel_initializer(4), name='Conv11')(x)
    x = layers.BatchNormalization(name='Batch11')(x)
    x = DefaultPooling(name='Pool10')(x)  # 128x128 -> 64x64

    # Level 2
    x = DefaultConv2D(16, kernel_initializer=kernel_initializer(56), name='Conv20')(x)
    x = layers.BatchNormalization(name='Batch20')(x)
    x = level_2 = DefaultConv2D(16, kernel_initializer=kernel_initializer(32), name='Conv21')(x)
    # NOTE(review): 'Batch22' breaks the naming pattern ('Batch21' expected);
    # kept unchanged so lookups or weights keyed by layer name keep working.
    x = layers.BatchNormalization(name='Batch22')(x)
    x = DefaultPooling(name='Pool20')(x)  # 64x64 -> 32x32

    # Level 3
    x = DefaultConv2D(32, kernel_initializer=kernel_initializer(87), name='Conv30')(x)
    x = layers.BatchNormalization(name='Batch30')(x)
    x = level_3 = DefaultConv2D(32, kernel_initializer=kernel_initializer(30), name='Conv31')(x)
    x = layers.BatchNormalization(name='Batch31')(x)
    x = DefaultPooling(name='Pool30')(x)  # 32x32 -> 16x16

    # Level 4
    x = DefaultConv2D(64, kernel_initializer=kernel_initializer(79), name='Conv40')(x)
    x = layers.BatchNormalization(name='Batch40')(x)
    x = level_4 = DefaultConv2D(64, kernel_initializer=kernel_initializer(81), name='Conv41')(x)
    x = layers.BatchNormalization(name='Batch41')(x)
    x = DefaultPooling(name='Pool40')(x)  # 16x16 -> 8x8

    # ---- Bottleneck ----
    x = DefaultConv2D(128, kernel_initializer=kernel_initializer(89), name='Conv50')(x)
    x = layers.BatchNormalization(name='Batch50')(x)
    x = DefaultConv2D(128, kernel_initializer=kernel_initializer(42), name='Conv51')(x)
    x = layers.BatchNormalization(name='Batch51')(x)

    # ---- Decoder ----
    # Level 4
    x = upsample(name='Up60')(x)  # 8x8 -> 16x16
    x = layers.Concatenate(name='Concat60')([level_4, x])
    x = DefaultConv2D(64, kernel_initializer=kernel_initializer(91), name='Conv60')(x)
    x = layers.BatchNormalization(name='Batch60')(x)
    x = DefaultConv2D(64, kernel_initializer=kernel_initializer(47), name='Conv61')(x)
    x = layers.BatchNormalization(name='Batch61')(x)

    # Level 3
    x = upsample(name='Up70')(x)  # 16x16 -> 32x32
    x = layers.Concatenate(name='Concat70')([level_3, x])
    x = DefaultConv2D(32, kernel_initializer=kernel_initializer(21), name='Conv70')(x)
    x = layers.BatchNormalization(name='Batch70')(x)
    x = DefaultConv2D(32, kernel_initializer=kernel_initializer(96), name='Conv71')(x)
    x = layers.BatchNormalization(name='Batch71')(x)

    # Level 2
    x = upsample(name='Up80')(x)  # 32x32 -> 64x64
    x = layers.Concatenate(name='Concat80')([level_2, x])
    # NOTE(review): seed 96 is also used by 'Conv71'; kept as-is so the
    # initial weights are unchanged.
    x = DefaultConv2D(16, kernel_initializer=kernel_initializer(96), name='Conv80')(x)
    x = layers.BatchNormalization(name='Batch80')(x)
    x = DefaultConv2D(16, kernel_initializer=kernel_initializer(98), name='Conv81')(x)
    x = layers.BatchNormalization(name='Batch81')(x)

    # Level 1
    x = upsample(name='Up90')(x)  # 64x64 -> 128x128
    x = layers.Concatenate(name='Concat90')([level_1, x])
    x = DefaultConv2D(8, kernel_initializer=kernel_initializer(35), name='Conv90')(x)
    x = layers.BatchNormalization(name='Batch90')(x)
    x = DefaultConv2D(8, kernel_initializer=kernel_initializer(7), name='Conv91')(x)
    x = layers.BatchNormalization(name='Batch91')(x)

    # ---- Output head: 1x1 convolution overriding the default kernel size
    # and activation of DefaultConv2D ----
    x = DefaultConv2D(out_channels,
                      kernel_size=(1, 1),
                      activation=out_ActFunction,
                      kernel_initializer=kernel_initializer(42),
                      name='Conv100')(x)

    model = Model(inputs, x, name=name)
    return model
if __name__ == '__main__':
    # When run as a script, build the model with its default arguments and
    # print the Keras layer summary.
    model = unet_baseline()
    model.summary()