@@ -21,57 +21,62 @@ class Unet():
21
21
https://arxiv.org/pdf/1505.04597
22
22
"""
23
23
24
- def __init__ (self ):
25
- pass
26
-
27
- def Build_UNetwork (self , input_shape = (572 , 572 , 1 ), filters = [64 , 128 , 256 , 512 , 1024 ]):
28
-
24
+ def __init__ (self , input_shape , filters , padding ):
29
25
"""
30
- Builds the Unet Model network.
26
+
27
+ Initialize the Unet framework and the model parameters - input_shape,
28
+ filters and padding type.
31
29
32
30
Args:
33
31
input_shape: The shape of the input to the network. A tuple comprising of (img_height, img_width, channels).
34
- Default shape is (572, 572, 1).
32
+ Original paper implementation is (572, 572, 1).
35
33
filters: a collection of filters denoting the number of components to be used at each blocks along the
36
- contracting and expansive paths. The default number of filters along the contracting and expansive paths are
37
- [64, 128, 256, 512, 1024].
34
+ contracting and expansive paths. The original paper implementation for number of filters along the
35
+ contracting and expansive paths are [64, 128, 256, 512, 1024].
36
+ padding: the padding type to be used during the convolution step. The original paper used unpadded convolutions
37
+ which is of type "valid".
38
38
39
39
**Remarks: The default values are as per the implementation in the original paper @ https://arxiv.org/pdf/1505.04597
40
-
40
+
41
+ """
42
+ self .input_shape = input_shape
43
+ self .filters = filters
44
+ self .padding = padding
45
+
46
+ def Build_UNetwork (self ):
47
+
48
+ """
49
+ Builds the Unet Model network.
50
+
51
+ Args:
52
+ None
53
+
41
54
Return:
42
55
The Unet Model.
43
56
44
- **Note - If the total number of filters are not sufficient to implement each block along the contracting
45
- and expansive path, then the return value is None.
46
-
47
57
"""
48
58
49
- if len (filters ) != 5 :
50
- print ("There are not sufficient filters to implement each block of the UNet model.\n Recheck the filters." )
51
- return None
52
59
53
- UnetInput = Input (input_shape )
60
+ UnetInput = Input (self . input_shape )
54
61
55
62
# the contracting path.
56
63
# the last item in the filetrs collection points to the number of filters in the bottleneck block.
57
- # so we loop till the 4th item.
58
- for num_filters in filters [:4 ]:
59
- conv1 , pool1 = UnetUtils .contracting_block (input_layer = UnetInput , filters = num_filters )
60
- conv2 , pool2 = UnetUtils .contracting_block (input_layer = pool1 , filters = num_filters )
61
- conv3 , pool3 = UnetUtils .contracting_block (input_layer = pool2 , filters = num_filters )
62
- conv4 , pool4 = UnetUtils .contracting_block (input_layer = pool3 , filters = num_filters )
64
+ conv1 , pool1 = UnetUtils .contracting_block (input_layer = UnetInput , filters = self .filters [0 ], padding = self .padding )
65
+ conv2 , pool2 = UnetUtils .contracting_block (input_layer = pool1 , filters = self .filters [1 ], padding = self .padding )
66
+ conv3 , pool3 = UnetUtils .contracting_block (input_layer = pool2 , filters = self .filters [2 ], padding = self .padding )
67
+ conv4 , pool4 = UnetUtils .contracting_block (input_layer = pool3 , filters = self .filters [3 ], padding = self .padding )
63
68
64
69
# bottleneck block connecting the contracting and the expansive paths.
65
- bottleNeck = UnetUtils .bottleneck_block (pool4 , filters = filters [- 1 ] )
70
+ bottleNeck = UnetUtils .bottleneck_block (pool4 , filters = self . filters [4 ], padding = self . padding )
66
71
67
- # the expansive path.essentially we loop the reversed filter list leaving out the last item.
68
- for num_filters in reversed (filters [:- 1 ]):
69
- upConv1 = UnetUtils .expansive_block (bottleNeck , conv4 , filters = num_filters )
70
- upConv2 = UnetUtils .expansive_block (upConv1 , conv3 , filters = num_filters )
71
- upConv3 = UnetUtils .expansive_block (upConv2 , conv2 , filters = num_filters )
72
- upConv4 = UnetUtils .expansive_block (upConv3 , conv1 , filters = num_filters )
72
+ # the expansive path.
73
+ upConv1 = UnetUtils .expansive_block (bottleNeck , conv4 , filters = self .filters [3 ], padding = self .padding )
74
+ upConv2 = UnetUtils .expansive_block (upConv1 , conv3 , filters = self .filters [2 ], padding = self .padding )
75
+ upConv3 = UnetUtils .expansive_block (upConv2 , conv2 , filters = self .filters [1 ], padding = self .padding )
76
+ upConv4 = UnetUtils .expansive_block (upConv3 , conv1 , filters = self .filters [0 ], padding = self .padding )
73
77
74
- UnetOutput = Conv2D (1 , (1 , 1 ), padding = "valid" , activation = "sigmoid" )(upConv4 )
78
+ UnetOutput = Conv2D (1 , (1 , 1 ), padding = self .padding , activation = tf .math .sigmoid )(upConv4 )
79
+
75
80
model = Model (UnetInput , UnetOutput , name = "UNet" )
76
81
77
82
return model
0 commit comments