
Commit 2454b30

update comments
1 parent 96e5fc2 commit 2454b30

7 files changed: +69 -177 lines changed


ssseg/modules/models/backbones/beit.py (+3 -10)
@@ -95,9 +95,7 @@ def forward(self, x):
         if self.relative_position_bias_table is not None:
             Wh = self.window_size[0]
             Ww = self.window_size[1]
-            relative_position_bias = self.relative_position_bias_table[
-                self.relative_position_index.view(-1)
-            ].view(Wh * Ww + 1, Wh * Ww + 1, -1)
+            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(Wh * Ww + 1, Wh * Ww + 1, -1)
             relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
             attn = attn + relative_position_bias.unsqueeze(0)
             attn = attn.softmax(dim=-1)
@@ -122,13 +120,8 @@ def __init__(self, embed_dims, num_heads, feedforward_channels, attn_drop_rate=0
         self.gamma_1 = nn.Parameter(init_values * torch.ones((embed_dims)), requires_grad=True)
         self.gamma_2 = nn.Parameter(init_values * torch.ones((embed_dims)), requires_grad=True)
         attn_cfg.update(dict(
-            window_size=window_size,
-            qk_scale=None,
-            embed_dims=embed_dims,
-            num_heads=num_heads,
-            attn_drop_rate=attn_drop_rate,
-            proj_drop_rate=0.,
-            bias=bias,
+            window_size=window_size, qk_scale=None, embed_dims=embed_dims, num_heads=num_heads,
+            attn_drop_rate=attn_drop_rate, proj_drop_rate=0., bias=bias,
         ))
         self.attn = BEiTAttention(**attn_cfg)
     '''forward'''
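Both beit.py changes are behavior-preserving: the table lookup and reshape are joined onto one line, and the BEiTAttention kwargs are packed more densely. For context, the lookup pattern itself in a minimal standalone form (shapes and tensor values here are hypothetical placeholders, not the repo's exact module, where relative_position_index is precomputed from window geometry):

```python
import torch

num_heads, Wh, Ww = 4, 7, 7
seq_len = Wh * Ww + 1  # window tokens plus one [CLS] token, as in BEiT

# one learnable bias per relative offset per head (the +3 covers cls-token cases)
num_distances = (2 * Wh - 1) * (2 * Ww - 1) + 3
relative_position_bias_table = torch.zeros(num_distances, num_heads)
# index mapping every (query, key) token pair to a row of the table
relative_position_index = torch.randint(0, num_distances, (seq_len, seq_len))

attn = torch.randn(1, num_heads, seq_len, seq_len)  # raw attention logits

# the lookup-and-reshape the diff collapses onto a single line
bias = relative_position_bias_table[relative_position_index.view(-1)].view(seq_len, seq_len, -1)
bias = bias.permute(2, 0, 1).contiguous()          # -> (num_heads, seq_len, seq_len)
attn = (attn + bias.unsqueeze(0)).softmax(dim=-1)  # broadcast over the batch dimension
```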

ssseg/modules/models/backbones/bisenetv1.py (+4 -4)
@@ -17,7 +17,7 @@
 AUTO_ASSERT_STRUCTURE_TYPES = {}
 
 
-'''Spatial Path to preserve the spatial size of the original input image and encode affluent spatial information'''
+'''SpatialPath'''
 class SpatialPath(nn.Module):
     def __init__(self, in_channels=3, num_channels_list=(64, 64, 64, 128), norm_cfg=None, act_cfg=None):
         super(SpatialPath, self).__init__()
@@ -53,7 +53,7 @@ def forward(self, x):
         return x
 
 
-'''Attention Refinement Module (ARM) to refine the features of each stage'''
+'''AttentionRefinementModule'''
 class AttentionRefinementModule(nn.Module):
     def __init__(self, in_channels, out_channels, norm_cfg=None, act_cfg=None):
         super(AttentionRefinementModule, self).__init__()
@@ -76,7 +76,7 @@ def forward(self, x):
         return x_out
 
 
-'''Context Path to provide sufficient receptive field'''
+'''ContextPath'''
 class ContextPath(nn.Module):
     def __init__(self, backbone_cfg, context_channels_list=(128, 256, 512), norm_cfg=None, act_cfg=None):
         super(ContextPath, self).__init__()
@@ -125,7 +125,7 @@ def buildbackbone(self, cfg):
         return supported_backbones[backbone_type](**cfg)
 
 
-'''Feature Fusion Module to fuse low level output feature of Spatial Path and high level output feature of Context Path'''
+'''FeatureFusionModule'''
 class FeatureFusionModule(nn.Module):
     def __init__(self, in_channels, out_channels, norm_cfg=None, act_cfg=None):
         super(FeatureFusionModule, self).__init__()
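All four bisenetv1.py edits swap descriptive docstrings for bare class names; the modules themselves are untouched. As a reminder of what the old ARM docstring described, a hedged sketch of the attention-refinement pattern in plain PyTorch (class name is hypothetical, and the repo's version also threads through its norm/act config):

```python
import torch.nn as nn

class ARMSketch(nn.Module):
    '''Refine stage features with a global channel gate: pool -> 1x1 conv -> sigmoid'''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.atten = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                            # (N, C, 1, 1) global context
            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
            nn.Sigmoid(),                                       # per-channel weights in (0, 1)
        )
    def forward(self, x):
        x = self.conv(x)
        return x * self.atten(x)                                # reweight channels per image
```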

ssseg/modules/models/backbones/bisenetv2.py (+25 -57)
@@ -17,7 +17,7 @@
 AUTO_ASSERT_STRUCTURE_TYPES = {}
 
 
-'''Detail Branch with wide channels and shallow layers to capture low-level details and generate high-resolution feature representation'''
+'''DetailBranch'''
 class DetailBranch(nn.Module):
     def __init__(self, detail_channels=(64, 64, 128), in_channels=3, norm_cfg=None, act_cfg=None):
         super(DetailBranch, self).__init__()
@@ -52,7 +52,7 @@ def forward(self, x):
         return x
 
 
-'''Stem Block at the beginning of Semantic Branch'''
+'''StemBlock'''
 class StemBlock(nn.Module):
     def __init__(self, in_channels=3, out_channels=16, norm_cfg=None, act_cfg=None):
         super(StemBlock, self).__init__()
@@ -84,7 +84,7 @@ def forward(self, x):
         return x
 
 
-'''Gather-and-Expansion Layer'''
+'''GELayer'''
 class GELayer(nn.Module):
     def __init__(self, in_channels, out_channels, exp_ratio=6, stride=1, norm_cfg=None, act_cfg=None):
         super(GELayer, self).__init__()
@@ -110,15 +110,8 @@ def __init__(self, in_channels, out_channels, exp_ratio=6, stride=1, norm_cfg=No
             BuildActivation(act_cfg),
         )
         self.shortcut = nn.Sequential(DepthwiseSeparableConv2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=stride,
-            padding=1,
-            dw_norm_cfg=norm_cfg,
-            dw_act_cfg=None,
-            pw_norm_cfg=norm_cfg,
-            pw_act_cfg=None,
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1,
+            dw_norm_cfg=norm_cfg, dw_act_cfg=None, pw_norm_cfg=norm_cfg, pw_act_cfg=None,
         ))
         self.conv2 = nn.Sequential(
             nn.Conv2d(mid_channel, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
@@ -140,7 +133,7 @@ def forward(self, x):
         return x
 
 
-'''Context Embedding Block for large receptive filed in Semantic Branch'''
+'''CEBlock'''
 class CEBlock(nn.Module):
     def __init__(self, in_channels=3, out_channels=16, norm_cfg=None, act_cfg=None):
         super(CEBlock, self).__init__()
@@ -172,7 +165,7 @@ def forward(self, x):
         return x
 
 
-'''Semantic Branch which is lightweight with narrow channels and deep layers to obtain high-level semantic context'''
+'''SemanticBranch'''
 class SemanticBranch(nn.Module):
     def __init__(self, semantic_channels=(16, 32, 64, 128), in_channels=3, exp_ratio=6, norm_cfg=None, act_cfg=None):
         super(SemanticBranch, self).__init__()
@@ -187,25 +180,18 @@ def __init__(self, semantic_channels=(16, 32, 64, 128), in_channels=3, exp_ratio
             if i == 0:
                 self.add_module(stage_name, StemBlock(in_channels, semantic_channels[i], norm_cfg=norm_cfg, act_cfg=act_cfg))
             elif i == (len(semantic_channels) - 1):
-                self.add_module(
-                    stage_name,
-                    nn.Sequential(
-                        GELayer(semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2, norm_cfg=norm_cfg, act_cfg=act_cfg),
-                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg),
-                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg),
-                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg),
-                    )
-                )
+                self.add_module(stage_name, nn.Sequential(
+                    GELayer(semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2, norm_cfg=norm_cfg, act_cfg=act_cfg),
+                    GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg),
+                    GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg),
+                    GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg),
+                ))
             else:
-                self.add_module(
-                    stage_name,
-                    nn.Sequential(
-                        GELayer(semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2, norm_cfg=norm_cfg, act_cfg=act_cfg),
-                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
-                    )
-                )
-        self.add_module(
-            f'stage{len(semantic_channels)}_CEBlock',
+                self.add_module(stage_name, nn.Sequential(
+                    GELayer(semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2, norm_cfg=norm_cfg, act_cfg=act_cfg),
+                    GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
+                ))
+        self.add_module(f'stage{len(semantic_channels)}_CEBlock',
             CEBlock(semantic_channels[-1], semantic_channels[-1], norm_cfg=norm_cfg, act_cfg=act_cfg),
         )
         self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
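Up to here the bisenetv2.py edits are the same two moves: docstrings shortened to class names and multi-line argument lists reflowed. Since the GELayer shortcut leans on DepthwiseSeparableConv2d, here is the factorization that helper performs, as a rough plain-PyTorch sketch (class name hypothetical; the real helper also accepts the dw_/pw_ norm and activation configs seen in the diff):

```python
import torch.nn as nn

class DepthwiseSeparableSketch(nn.Module):
    '''3x3 depthwise conv (one filter per channel) followed by a 1x1 pointwise conv'''
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride=stride,
            padding=padding, groups=in_channels, bias=False,    # groups=C -> depthwise
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    def forward(self, x):
        return self.pointwise(self.depthwise(x))                # spatial, then channel mixing
```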
@@ -219,7 +205,7 @@ def forward(self, x):
         return semantic_outs
 
 
-'''Bilateral Guided Aggregation Layer to fuse the complementary information from both Detail Branch and Semantic Branch'''
+'''BGALayer'''
 class BGALayer(nn.Module):
     def __init__(self, out_channels=128, align_corners=False, norm_cfg=None, act_cfg=None):
         super(BGALayer, self).__init__()
@@ -228,15 +214,8 @@ def __init__(self, out_channels=128, align_corners=False, norm_cfg=None, act_cfg
         self.align_corners = align_corners
         # define modules
         self.detail_dwconv = nn.Sequential(DepthwiseSeparableConv2d(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            dw_norm_cfg=norm_cfg,
-            dw_act_cfg=None,
-            pw_norm_cfg=None,
-            pw_act_cfg=None,
+            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1,
+            dw_norm_cfg=norm_cfg, dw_act_cfg=None, pw_norm_cfg=None, pw_act_cfg=None,
         ))
         self.detail_down = nn.Sequential(
             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
@@ -248,15 +227,8 @@ def __init__(self, out_channels=128, align_corners=False, norm_cfg=None, act_cfg
             BuildNormalization(placeholder=out_channels, norm_cfg=norm_cfg),
         )
         self.semantic_dwconv = nn.Sequential(DepthwiseSeparableConv2d(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            dw_norm_cfg=norm_cfg,
-            dw_act_cfg=None,
-            pw_norm_cfg=None,
-            pw_act_cfg=None,
+            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1,
+            dw_norm_cfg=norm_cfg, dw_act_cfg=None, pw_norm_cfg=None, pw_act_cfg=None,
         ))
         self.conv = nn.Sequential(
             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
@@ -269,14 +241,10 @@ def forward(self, x_d, x_s):
         detail_down = self.detail_down(x_d)
         semantic_conv = self.semantic_conv(x_s)
         semantic_dwconv = self.semantic_dwconv(x_s)
-        semantic_conv = F.interpolate(
-            semantic_conv, size=detail_dwconv.shape[2:], mode='bilinear', align_corners=self.align_corners,
-        )
+        semantic_conv = F.interpolate(semantic_conv, size=detail_dwconv.shape[2:], mode='bilinear', align_corners=self.align_corners)
         fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
         fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
-        fuse_2 = F.interpolate(
-            fuse_2, size=fuse_1.shape[2:], mode='bilinear', align_corners=self.align_corners
-        )
+        fuse_2 = F.interpolate(fuse_2, size=fuse_1.shape[2:], mode='bilinear', align_corners=self.align_corners)
         output = self.conv(fuse_1 + fuse_2)
         return output
 
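The BGALayer changes are again pure reflow: keyword arguments and the two F.interpolate calls joined onto single lines. The bilateral fusion those lines implement, roughly, with plain tensors (a hedged sketch; shapes are hypothetical, the pooling stands in for the layer's strided-conv branches, and the real layer ends with a 3x3 conv plus norm):

```python
import torch
import torch.nn.functional as F

x_d = torch.randn(1, 128, 64, 64)  # detail branch: high resolution
x_s = torch.randn(1, 128, 16, 16)  # semantic branch: low resolution, wide context

# high-res path: detail features gated by upsampled semantic attention
gate_hi = F.interpolate(x_s, size=x_d.shape[2:], mode='bilinear', align_corners=False)
fuse_1 = x_d * torch.sigmoid(gate_hi)

# low-res path: downsampled detail features gated at semantic resolution
x_d_down = F.avg_pool2d(x_d, kernel_size=4, stride=4)  # stand-in for the strided downsample
fuse_2 = x_d_down * torch.sigmoid(x_s)
fuse_2 = F.interpolate(fuse_2, size=fuse_1.shape[2:], mode='bilinear', align_corners=False)

output = fuse_1 + fuse_2  # the real layer applies self.conv(fuse_1 + fuse_2)
```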

ssseg/modules/models/backbones/cgnet.py (+7 -19)
@@ -16,7 +16,7 @@
 AUTO_ASSERT_STRUCTURE_TYPES = {}
 
 
-'''Global Context Extractor for CGNet'''
+'''GlobalContextExtractor'''
 class GlobalContextExtractor(nn.Module):
     def __init__(self, channels, reduction=16):
         super(GlobalContextExtractor, self).__init__()
@@ -38,7 +38,7 @@ def forward(self, x):
         return x * y
 
 
-'''Context Guided Block for CGNet'''
+'''ContextGuidedBlock'''
 class ContextGuidedBlock(nn.Module):
     def __init__(self, in_channels, out_channels, dilation=2, reduction=16, skip_connect=True, downsample=False, norm_cfg=None, act_cfg=None):
         super(ContextGuidedBlock, self).__init__()
@@ -80,7 +80,7 @@ def forward(self, x):
         return out
 
 
-'''Downsampling module for CGNet'''
+'''InputInjection'''
 class InputInjection(nn.Module):
     def __init__(self, num_downsamplings):
         super(InputInjection, self).__init__()
@@ -144,14 +144,8 @@ def __init__(self, structure_type, in_channels=3, num_channels=(32, 64, 128), nu
         self.level1 = nn.ModuleList()
         for i in range(num_blocks[0]):
             self.level1.append(ContextGuidedBlock(
-                in_channels=cur_channels if i == 0 else num_channels[1],
-                out_channels=num_channels[1],
-                dilation=dilations[0],
-                reduction=reductions[0],
-                skip_connect=True,
-                downsample=(i == 0),
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg,
+                in_channels=cur_channels if i == 0 else num_channels[1], out_channels=num_channels[1], dilation=dilations[0],
+                reduction=reductions[0], skip_connect=True, downsample=(i == 0), norm_cfg=norm_cfg, act_cfg=act_cfg,
             ))
         cur_channels = 2 * num_channels[1] + in_channels
         self.norm_prelu_1 = nn.Sequential(
@@ -162,14 +156,8 @@ def __init__(self, structure_type, in_channels=3, num_channels=(32, 64, 128), nu
         self.level2 = nn.ModuleList()
         for i in range(num_blocks[1]):
             self.level2.append(ContextGuidedBlock(
-                in_channels=cur_channels if i == 0 else num_channels[2],
-                out_channels=num_channels[2],
-                dilation=dilations[1],
-                reduction=reductions[1],
-                skip_connect=True,
-                downsample=(i == 0),
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg,
+                in_channels=cur_channels if i == 0 else num_channels[2], out_channels=num_channels[2], dilation=dilations[1],
+                reduction=reductions[1], skip_connect=True, downsample=(i == 0), norm_cfg=norm_cfg, act_cfg=act_cfg,
             ))
         cur_channels = 2 * num_channels[2]
         self.norm_prelu_2 = nn.Sequential(
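The cgnet.py edits follow the same pattern: docstrings reduced to class names and the two ContextGuidedBlock argument lists reflowed. The GlobalContextExtractor renamed above is a squeeze-and-excitation style channel gate; a hedged sketch of that pattern in plain PyTorch (class name hypothetical; details may differ from the repo's version):

```python
import torch.nn as nn

class GlobalContextExtractorSketch(nn.Module):
    '''Channel gate: global average pool -> bottleneck MLP -> per-channel rescale'''
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction), nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels), nn.Sigmoid(),
        )
    def forward(self, x):
        n, c = x.size(0), x.size(1)
        y = self.avg_pool(x).view(n, c)   # squeeze spatial dims to a channel vector
        y = self.fc(y).view(n, c, 1, 1)   # per-channel weights in (0, 1)
        return x * y                      # excite: rescale each channel
```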
