17
17
AUTO_ASSERT_STRUCTURE_TYPES = {}
18
18
19
19
20
- '''Detail Branch with wide channels and shallow layers to capture low-level details and generate high-resolution feature representation '''
20
+ '''DetailBranch '''
21
21
class DetailBranch (nn .Module ):
22
22
def __init__ (self , detail_channels = (64 , 64 , 128 ), in_channels = 3 , norm_cfg = None , act_cfg = None ):
23
23
super (DetailBranch , self ).__init__ ()
@@ -52,7 +52,7 @@ def forward(self, x):
52
52
return x
53
53
54
54
55
- '''Stem Block at the beginning of Semantic Branch '''
55
+ '''StemBlock '''
56
56
class StemBlock (nn .Module ):
57
57
def __init__ (self , in_channels = 3 , out_channels = 16 , norm_cfg = None , act_cfg = None ):
58
58
super (StemBlock , self ).__init__ ()
@@ -84,7 +84,7 @@ def forward(self, x):
84
84
return x
85
85
86
86
87
- '''Gather-and-Expansion Layer '''
87
+ '''GELayer '''
88
88
class GELayer (nn .Module ):
89
89
def __init__ (self , in_channels , out_channels , exp_ratio = 6 , stride = 1 , norm_cfg = None , act_cfg = None ):
90
90
super (GELayer , self ).__init__ ()
@@ -110,15 +110,8 @@ def __init__(self, in_channels, out_channels, exp_ratio=6, stride=1, norm_cfg=No
110
110
BuildActivation (act_cfg ),
111
111
)
112
112
self .shortcut = nn .Sequential (DepthwiseSeparableConv2d (
113
- in_channels = in_channels ,
114
- out_channels = out_channels ,
115
- kernel_size = 3 ,
116
- stride = stride ,
117
- padding = 1 ,
118
- dw_norm_cfg = norm_cfg ,
119
- dw_act_cfg = None ,
120
- pw_norm_cfg = norm_cfg ,
121
- pw_act_cfg = None ,
113
+ in_channels = in_channels , out_channels = out_channels , kernel_size = 3 , stride = stride , padding = 1 ,
114
+ dw_norm_cfg = norm_cfg , dw_act_cfg = None , pw_norm_cfg = norm_cfg , pw_act_cfg = None ,
122
115
))
123
116
self .conv2 = nn .Sequential (
124
117
nn .Conv2d (mid_channel , out_channels , kernel_size = 1 , stride = 1 , padding = 0 , bias = False ),
@@ -140,7 +133,7 @@ def forward(self, x):
140
133
return x
141
134
142
135
143
- '''Context Embedding Block for large receptive filed in Semantic Branch '''
136
+ '''CEBlock '''
144
137
class CEBlock (nn .Module ):
145
138
def __init__ (self , in_channels = 3 , out_channels = 16 , norm_cfg = None , act_cfg = None ):
146
139
super (CEBlock , self ).__init__ ()
@@ -172,7 +165,7 @@ def forward(self, x):
172
165
return x
173
166
174
167
175
- '''Semantic Branch which is lightweight with narrow channels and deep layers to obtain high-level semantic context '''
168
+ '''SemanticBranch '''
176
169
class SemanticBranch (nn .Module ):
177
170
def __init__ (self , semantic_channels = (16 , 32 , 64 , 128 ), in_channels = 3 , exp_ratio = 6 , norm_cfg = None , act_cfg = None ):
178
171
super (SemanticBranch , self ).__init__ ()
@@ -187,25 +180,18 @@ def __init__(self, semantic_channels=(16, 32, 64, 128), in_channels=3, exp_ratio
187
180
if i == 0 :
188
181
self .add_module (stage_name , StemBlock (in_channels , semantic_channels [i ], norm_cfg = norm_cfg , act_cfg = act_cfg ))
189
182
elif i == (len (semantic_channels ) - 1 ):
190
- self .add_module (
191
- stage_name ,
192
- nn .Sequential (
193
- GELayer (semantic_channels [i - 1 ], semantic_channels [i ], exp_ratio , 2 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
194
- GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
195
- GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
196
- GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
197
- )
198
- )
183
+ self .add_module (stage_name , nn .Sequential (
184
+ GELayer (semantic_channels [i - 1 ], semantic_channels [i ], exp_ratio , 2 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
185
+ GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
186
+ GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
187
+ GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
188
+ ))
199
189
else :
200
- self .add_module (
201
- stage_name ,
202
- nn .Sequential (
203
- GELayer (semantic_channels [i - 1 ], semantic_channels [i ], exp_ratio , 2 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
204
- GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg )
205
- )
206
- )
207
- self .add_module (
208
- f'stage{ len (semantic_channels )} _CEBlock' ,
190
+ self .add_module (stage_name , nn .Sequential (
191
+ GELayer (semantic_channels [i - 1 ], semantic_channels [i ], exp_ratio , 2 , norm_cfg = norm_cfg , act_cfg = act_cfg ),
192
+ GELayer (semantic_channels [i ], semantic_channels [i ], exp_ratio , 1 , norm_cfg = norm_cfg , act_cfg = act_cfg )
193
+ ))
194
+ self .add_module (f'stage{ len (semantic_channels )} _CEBlock' ,
209
195
CEBlock (semantic_channels [- 1 ], semantic_channels [- 1 ], norm_cfg = norm_cfg , act_cfg = act_cfg ),
210
196
)
211
197
self .semantic_stages .append (f'stage{ len (semantic_channels )} _CEBlock' )
@@ -219,7 +205,7 @@ def forward(self, x):
219
205
return semantic_outs
220
206
221
207
222
- '''Bilateral Guided Aggregation Layer to fuse the complementary information from both Detail Branch and Semantic Branch '''
208
+ '''BGALayer '''
223
209
class BGALayer (nn .Module ):
224
210
def __init__ (self , out_channels = 128 , align_corners = False , norm_cfg = None , act_cfg = None ):
225
211
super (BGALayer , self ).__init__ ()
@@ -228,15 +214,8 @@ def __init__(self, out_channels=128, align_corners=False, norm_cfg=None, act_cfg
228
214
self .align_corners = align_corners
229
215
# define modules
230
216
self .detail_dwconv = nn .Sequential (DepthwiseSeparableConv2d (
231
- in_channels = out_channels ,
232
- out_channels = out_channels ,
233
- kernel_size = 3 ,
234
- stride = 1 ,
235
- padding = 1 ,
236
- dw_norm_cfg = norm_cfg ,
237
- dw_act_cfg = None ,
238
- pw_norm_cfg = None ,
239
- pw_act_cfg = None ,
217
+ in_channels = out_channels , out_channels = out_channels , kernel_size = 3 , stride = 1 , padding = 1 ,
218
+ dw_norm_cfg = norm_cfg , dw_act_cfg = None , pw_norm_cfg = None , pw_act_cfg = None ,
240
219
))
241
220
self .detail_down = nn .Sequential (
242
221
nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 2 , padding = 1 , bias = False ),
@@ -248,15 +227,8 @@ def __init__(self, out_channels=128, align_corners=False, norm_cfg=None, act_cfg
248
227
BuildNormalization (placeholder = out_channels , norm_cfg = norm_cfg ),
249
228
)
250
229
self .semantic_dwconv = nn .Sequential (DepthwiseSeparableConv2d (
251
- in_channels = out_channels ,
252
- out_channels = out_channels ,
253
- kernel_size = 3 ,
254
- stride = 1 ,
255
- padding = 1 ,
256
- dw_norm_cfg = norm_cfg ,
257
- dw_act_cfg = None ,
258
- pw_norm_cfg = None ,
259
- pw_act_cfg = None ,
230
+ in_channels = out_channels , out_channels = out_channels , kernel_size = 3 , stride = 1 , padding = 1 ,
231
+ dw_norm_cfg = norm_cfg , dw_act_cfg = None , pw_norm_cfg = None , pw_act_cfg = None ,
260
232
))
261
233
self .conv = nn .Sequential (
262
234
nn .Conv2d (out_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
@@ -269,14 +241,10 @@ def forward(self, x_d, x_s):
269
241
detail_down = self .detail_down (x_d )
270
242
semantic_conv = self .semantic_conv (x_s )
271
243
semantic_dwconv = self .semantic_dwconv (x_s )
272
- semantic_conv = F .interpolate (
273
- semantic_conv , size = detail_dwconv .shape [2 :], mode = 'bilinear' , align_corners = self .align_corners ,
274
- )
244
+ semantic_conv = F .interpolate (semantic_conv , size = detail_dwconv .shape [2 :], mode = 'bilinear' , align_corners = self .align_corners )
275
245
fuse_1 = detail_dwconv * torch .sigmoid (semantic_conv )
276
246
fuse_2 = detail_down * torch .sigmoid (semantic_dwconv )
277
- fuse_2 = F .interpolate (
278
- fuse_2 , size = fuse_1 .shape [2 :], mode = 'bilinear' , align_corners = self .align_corners
279
- )
247
+ fuse_2 = F .interpolate (fuse_2 , size = fuse_1 .shape [2 :], mode = 'bilinear' , align_corners = self .align_corners )
280
248
output = self .conv (fuse_1 + fuse_2 )
281
249
return output
282
250
0 commit comments