@@ -26,7 +26,7 @@ def __init__(
26
26
encoder_name : str = 'resnet34' ,
27
27
encoder_depth : int = 5 ,
28
28
encoder_weights : str | None = 'imagenet' ,
29
- decoder_use_batchnorm : bool = True ,
29
+ decoder_use_batchnorm : bool | str | dict [ str , Any ] = 'batchnorm' ,
30
30
decoder_channels : Sequence [int ] = (256 , 128 , 64 , 32 , 16 ),
31
31
decoder_attention_type : str | None = None ,
32
32
in_channels : int = 3 ,
@@ -50,10 +50,26 @@ def __init__(
50
50
decoder_channels: List of integers which specify **in_channels**
51
51
parameter for convolutions used in decoder. Length of the list
52
52
should be the same as **encoder_depth**
53
- decoder_use_batchnorm: If **True**, BatchNorm2d layer between
54
- Conv2D and Activation layers is used. If **"inplace"** InplaceABN
55
- will be used, allows to decrease memory consumption. Available
56
- options are **True, False, "inplace"**
53
+ decoder_use_batchnorm: Specifies normalization between Conv2D and
54
+ activation. Accepts the following types:
55
+
56
+ - **True**: Defaults to `"batchnorm"`.
57
+ - **False**: No normalization (`nn.Identity`).
58
+ - **str**: Specifies normalization type using default parameters.
59
+ Available values: `"batchnorm"`, `"identity"`, `"layernorm"`,
60
+ `"instancenorm"`, `"inplace"`.
61
+ - **dict**: Fully customizable normalization settings. Structure:
62
+ ```python
63
+ {"type": <norm_type>, **kwargs}
64
+ ```
65
+ where `norm_name` corresponds to normalization type (see above), and
66
+ `kwargs` are passed directly to the normalization layer as defined in
67
+ PyTorch documentation.
68
+
69
+ **Example**:
70
+ ```python
71
+ decoder_use_norm={"type": "layernorm", "eps": 1e-2}
72
+ ```
57
73
decoder_attention_type: Attention module used in decoder of the model.
58
74
Available options are **None** and **scse**. SCSE paper
59
75
https://arxiv.org/abs/1808.08127
@@ -79,9 +95,9 @@ def __init__(
79
95
encoder_channels = encoder_out_channels ,
80
96
decoder_channels = decoder_channels ,
81
97
n_blocks = encoder_depth ,
82
- use_batchnorm = decoder_use_batchnorm ,
83
- center = True if encoder_name .startswith ('vgg' ) else False ,
98
+ use_norm = decoder_use_batchnorm ,
84
99
attention_type = decoder_attention_type ,
100
+ add_center_block = True if encoder_name .startswith ('vgg' ) else False ,
85
101
)
86
102
87
103
self .segmentation_head = smp .base .SegmentationHead (
@@ -111,7 +127,7 @@ def forward(self, x: Tensor) -> Tensor:
111
127
for i in range (1 , len (features1 ))
112
128
]
113
129
features .insert (0 , features2 [0 ])
114
- decoder_output = self .decoder (* features )
130
+ decoder_output = self .decoder (features )
115
131
masks : Tensor = self .segmentation_head (decoder_output )
116
132
return masks
117
133
@@ -150,6 +166,6 @@ def forward(self, x: Tensor) -> Tensor:
150
166
features1 , features2 = self .encoder (x1 ), self .encoder (x2 )
151
167
features = [features2 [i ] - features1 [i ] for i in range (1 , len (features1 ))]
152
168
features .insert (0 , features2 [0 ])
153
- decoder_output = self .decoder (* features )
169
+ decoder_output = self .decoder (features )
154
170
masks : Tensor = self .segmentation_head (decoder_output )
155
171
return masks
0 commit comments