Skip to content

Commit 374815d

Browse files
Updated to the UNet framework
Extended the implementation to be more generic, supporting variations to the network in terms of the inputs and filters used at each convolution block/step.
1 parent 4aeb23c commit 374815d

File tree

3 files changed

+60
-44
lines changed

3 files changed

+60
-44
lines changed

UNet - Biomedical_Segmentation.ipynb

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@
7070
"outputs": [],
7171
"source": [
7272
"# import the UNet design framework module.\n",
73-
"from Unet import Unet\n",
74-
"unet = Unet()"
73+
"from Unet import Unet"
7574
]
7675
},
7776
{
@@ -80,6 +79,13 @@
8079
"metadata": {},
8180
"outputs": [],
8281
"source": [
82+
"# Initialize the Unet with the default parameters. \n",
83+
"# The default params are the one that were used in the original paper.\n",
84+
"# Input shape - (572, 572, 1), \n",
85+
"# filters 64, 128, 256, 512, 1024 at each convolutional block and \n",
86+
"# unpadded convolutions.\n",
87+
"unet = Unet(input_shape = (572, 572, 1), filters = [64, 128, 256, 512, 1024], padding = \"valid\")\n",
88+
"\n",
8389
"# call the build network API to build the network.\n",
8490
"model = unet.Build_UNetwork()"
8591
]
@@ -182,18 +188,20 @@
182188
],
183189
"source": [
184190
"# compile & summarize the model\n",
185-
"unet.CompileAndSummarizeModel(model = model)"
191+
"if model is not None:\n",
192+
" unet.CompileAndSummarizeModel(model = model)"
186193
]
187194
},
188195
{
189196
"cell_type": "code",
190-
"execution_count": 4,
197+
"execution_count": null,
191198
"metadata": {
192199
"scrolled": false
193200
},
194201
"outputs": [],
195202
"source": [
196-
"unet.plotModel(model, dpi = 112)"
203+
"if model is not None:\n",
204+
" unet.plotModel(model, dpi = 112)"
197205
]
198206
},
199207
{

Unet.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,57 +21,62 @@ class Unet():
2121
https://arxiv.org/pdf/1505.04597
2222
"""
2323

24-
def __init__(self):
25-
pass
26-
27-
def Build_UNetwork(self, input_shape = (572, 572, 1), filters = [64, 128, 256, 512, 1024]):
28-
24+
def __init__(self, input_shape, filters, padding):
2925
"""
30-
Builds the Unet Model network.
26+
27+
Initialize the Unet framework and the model parameters - input_shape,
28+
filters and padding type.
3129
3230
Args:
3331
input_shape: The shape of the input to the network. A tuple comprising of (img_height, img_width, channels).
34-
Default shape is (572, 572, 1).
32+
Original paper implementation is (572, 572, 1).
3533
filters: a collection of filters denoting the number of components to be used at each blocks along the
36-
contracting and expansive paths. The default number of filters along the contracting and expansive paths are
37-
[64, 128, 256, 512, 1024].
34+
contracting and expansive paths. The original paper implementation for number of filters along the
35+
contracting and expansive paths are [64, 128, 256, 512, 1024].
36+
padding: the padding type to be used during the convolution step. The original paper used unpadded convolutions
37+
which is of type "valid".
3838
3939
**Remarks: The default values are as per the implementation in the original paper @ https://arxiv.org/pdf/1505.04597
40-
40+
41+
"""
42+
self.input_shape = input_shape
43+
self.filters = filters
44+
self.padding = padding
45+
46+
def Build_UNetwork(self):
47+
48+
"""
49+
Builds the Unet Model network.
50+
51+
Args:
52+
None
53+
4154
Return:
4255
The Unet Model.
4356
44-
**Note - If the total number of filters are not sufficient to implement each block along the contracting
45-
and expansive path, then the return value is None.
46-
4757
"""
4858

49-
if len(filters) != 5:
50-
print("There are not sufficient filters to implement each block of the UNet model.\nRecheck the filters.")
51-
return None
5259

53-
UnetInput = Input(input_shape)
60+
UnetInput = Input(self.input_shape)
5461

5562
# the contracting path.
5663
# the last item in the filters collection points to the number of filters in the bottleneck block.
57-
# so we loop till the 4th item.
58-
for num_filters in filters[:4]:
59-
conv1, pool1 = UnetUtils.contracting_block(input_layer = UnetInput, filters = num_filters)
60-
conv2, pool2 = UnetUtils.contracting_block(input_layer = pool1, filters = num_filters)
61-
conv3, pool3 = UnetUtils.contracting_block(input_layer = pool2, filters = num_filters)
62-
conv4, pool4 = UnetUtils.contracting_block(input_layer = pool3, filters = num_filters)
64+
conv1, pool1 = UnetUtils.contracting_block(input_layer = UnetInput, filters = self.filters[0], padding = self.padding)
65+
conv2, pool2 = UnetUtils.contracting_block(input_layer = pool1, filters = self.filters[1], padding = self.padding)
66+
conv3, pool3 = UnetUtils.contracting_block(input_layer = pool2, filters = self.filters[2], padding = self.padding)
67+
conv4, pool4 = UnetUtils.contracting_block(input_layer = pool3, filters = self.filters[3], padding = self.padding)
6368

6469
# bottleneck block connecting the contracting and the expansive paths.
65-
bottleNeck = UnetUtils.bottleneck_block(pool4, filters = filters[-1])
70+
bottleNeck = UnetUtils.bottleneck_block(pool4, filters = self.filters[4], padding = self.padding)
6671

67-
# the expansive path.essentially we loop the reversed filter list leaving out the last item.
68-
for num_filters in reversed(filters[:-1]):
69-
upConv1 = UnetUtils.expansive_block(bottleNeck, conv4, filters = num_filters)
70-
upConv2 = UnetUtils.expansive_block(upConv1, conv3, filters = num_filters)
71-
upConv3 = UnetUtils.expansive_block(upConv2, conv2, filters = num_filters)
72-
upConv4 = UnetUtils.expansive_block(upConv3, conv1, filters = num_filters)
72+
# the expansive path.
73+
upConv1 = UnetUtils.expansive_block(bottleNeck, conv4, filters = self.filters[3], padding = self.padding)
74+
upConv2 = UnetUtils.expansive_block(upConv1, conv3, filters = self.filters[2], padding = self.padding)
75+
upConv3 = UnetUtils.expansive_block(upConv2, conv2, filters = self.filters[1], padding = self.padding)
76+
upConv4 = UnetUtils.expansive_block(upConv3, conv1, filters = self.filters[0], padding = self.padding)
7377

74-
UnetOutput = Conv2D(1, (1, 1), padding = "valid", activation = "sigmoid")(upConv4)
78+
UnetOutput = Conv2D(1, (1, 1), padding = self.padding, activation = tf.math.sigmoid)(upConv4)
79+
7580
model = Model(UnetInput, UnetOutput, name = "UNet")
7681

7782
return model

UnetUtils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class UnetUtils():
2121
def __init__(self):
2222
pass
2323

24-
def contracting_block(self, input_layer, filters, kernel_size = 3, padding = "valid"):
24+
def contracting_block(self, input_layer, filters, padding, kernel_size = 3):
2525

2626
"""
2727
UNet Contracting block
@@ -56,7 +56,7 @@ def contracting_block(self, input_layer, filters, kernel_size = 3, padding = "va
5656

5757
return conv, pool
5858

59-
def bottleneck_block(self, input_layer, filters, kernel_size = 3, padding = "valid", strides = 1):
59+
def bottleneck_block(self, input_layer, filters, padding, kernel_size = 3, strides = 1):
6060

6161
"""
6262
UNet bottleneck block
@@ -90,7 +90,7 @@ def bottleneck_block(self, input_layer, filters, kernel_size = 3, padding = "val
9090

9191
return conv
9292

93-
def expansive_block(self, input_layer, skip_conn_layer, filters, kernel_size = 3, padding = "valid", strides = 1):
93+
def expansive_block(self, input_layer, skip_conn_layer, filters, padding, kernel_size = 3, strides = 1):
9494

9595
"""
9696
UNet expansive (upsample) block.
@@ -123,12 +123,15 @@ def expansive_block(self, input_layer, skip_conn_layer, filters, kernel_size = 3
123123
strides = 2,
124124
padding = padding)(input_layer)
125125

126-
# crop the spurce feature map so that the skip connection can be established.
127-
# Cropping is necessary due to the loss of border pixels in every convolution.
128-
cropped = self.crop_tensor(skip_conn_layer, transConv)
129-
126+
# crop the source feature map so that the skip connection can be established.
127+
# the original paper implemented unpadded convolutions. So cropping is necessary
128+
# due to the loss of border pixels in every convolution.
130129
# establish the skip connections.
131-
concat = Concatenate()([transConv, cropped])
130+
if padding == "valid":
131+
cropped = self.crop_tensor(skip_conn_layer, transConv)
132+
concat = Concatenate()([transConv, cropped])
133+
else:
134+
concat = Concatenate()([transConv, skip_conn_layer])
132135

133136
# two 3x3 convolutions, each followed by a ReLU
134137
up_conv = Conv2D(filters = filters,

0 commit comments

Comments
 (0)