@@ -3,37 +3,49 @@
 # Adapted from https://github.com/tmp-iclr/convmixer
 # Home for convmixer: https://github.com/locuslab/convmixer
 from collections import OrderedDict
-from typing import Callable
+from typing import Callable, List, Optional, Union
+
 import torch.nn as nn
+from torch import TensorType
 
 
 class Residual(nn.Module):
-    def __init__(self, fn):
+    def __init__(self, fn: Callable[[TensorType], TensorType]):
         super().__init__()
         self.fn = fn
 
-    def forward(self, x):
+    def forward(self, x: TensorType) -> TensorType:
         return self.fn(x) + x
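
Residual simply adds the block input to the block output, so whatever fn it wraps must preserve the tensor shape. A quick shape check (a sketch, assuming only torch plus the definitions above; the sizes are arbitrary):

    import torch

    # A conv with "same" padding keeps H and W, so the skip-add is valid.
    block = Residual(nn.Conv2d(8, 8, kernel_size=3, padding="same"))
    x = torch.randn(2, 8, 16, 16)
    assert block(x).shape == x.shape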
 
 
 # Same as the original version, but with act_fn as an argument.
-def ConvMixerOriginal(dim, depth,
-                      kernel_size=9, patch_size=7, n_classes=1000,
-                      act_fn=nn.GELU()):
+def ConvMixerOriginal(
+    dim: int,
+    depth: int,
+    kernel_size: int = 9,
+    patch_size: int = 7,
+    n_classes: int = 1000,
+    act_fn: nn.Module = nn.GELU(),
+):
     return nn.Sequential(
         nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
         act_fn,
         nn.BatchNorm2d(dim),
-        *[nn.Sequential(
-            Residual(nn.Sequential(
-                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+        *[
+            nn.Sequential(
+                Residual(
+                    nn.Sequential(
+                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+                        act_fn,
+                        nn.BatchNorm2d(dim),
+                    )
+                ),
+                nn.Conv2d(dim, dim, kernel_size=1),
                 act_fn,
-                nn.BatchNorm2d(dim)
-            )),
-            nn.Conv2d(dim, dim, kernel_size=1),
-            act_fn,
-            nn.BatchNorm2d(dim)
-        ) for i in range(depth)],
+                nn.BatchNorm2d(dim),
+            )
+            for _i in range(depth)
+        ],
         nn.AdaptiveAvgPool2d((1, 1)),
         nn.Flatten(),
         nn.Linear(dim, n_classes)
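
The factory returns a plain nn.Sequential: a patch-embedding stem, depth mixer blocks (a depthwise conv in a residual branch followed by a pointwise conv), and a pooled linear head. A minimal smoke test (a sketch; the sizes are chosen small so it runs fast):

    import torch

    model = ConvMixerOriginal(dim=64, depth=2, n_classes=10)
    out = model(torch.randn(2, 3, 224, 224))  # 224 / patch_size=7 -> 32x32 patches
    assert out.shape == (2, 10)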
@@ -43,15 +55,35 @@ def ConvMixerOriginal(dim, depth,
 class ConvLayer(nn.Sequential):
     """Basic conv layer block"""
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 act_fn=nn.GELU(), padding=0, groups=1,
-                 bn_1st=False, pre_act=False):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, tuple[int, int]],
+        stride: int = 1,
+        act_fn: nn.Module = nn.GELU(),
+        padding: Union[int, str] = 0,
+        groups: int = 1,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+    ):
 
-        conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
-                                         padding=padding, groups=groups))]
-        act_bn = [
-            ('act_fn', act_fn),
-            ('bn', nn.BatchNorm2d(out_channels))
+        conv_layer: List[tuple[str, nn.Module]] = [
+            (
+                "conv",
+                nn.Conv2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    groups=groups,
+                ),
+            )
+        ]
+        act_bn: List[tuple[str, nn.Module]] = [
+            ("act_fn", act_fn),
+            ("bn", nn.BatchNorm2d(out_channels)),
         ]
         if bn_1st:
             act_bn.reverse()
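
bn_1st only reverses the act_bn list, turning the post-conv order act_fn -> bn into bn -> act_fn; the elided remainder of __init__ presumably joins conv_layer and act_bn (with act_bn first when pre_act=True) into the OrderedDict handed to nn.Sequential. Under that assumption, inspecting the children makes the switch visible:

    # Hypothetical check; it relies on the elided assembly code described above.
    layer = ConvLayer(3, 32, kernel_size=3, padding=1, bn_1st=True)
    print([name for name, _ in layer.named_children()])  # expected: ['conv', 'bn', 'act_fn']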
@@ -64,45 +96,79 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
 
 
 class ConvMixer(nn.Sequential):
-
-    def __init__(self, dim: int, depth: int,
-                 kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
-                 act_fn: nn.Module = nn.GELU(),
-                 stem: nn.Module = None,
-                 bn_1st: bool = False, pre_act: bool = False,
-                 init_func: Callable = None):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        kernel_size: int = 9,
+        patch_size: int = 7,
+        n_classes: int = 1000,
+        act_fn: nn.Module = nn.GELU(),
+        stem: Optional[nn.Module] = None,
+        in_chans: int = 3,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+        init_func: Optional[Callable[[nn.Module], None]] = None,
+    ):
         """ConvMixer constructor.
         Adapted from https://github.com/tmp-iclr/convmixer
 
         Args:
-            dim (int): Dimention of model.
+            dim (int): Dimension of model.
             depth (int): Depth of model.
             kernel_size (int, optional): Kernel size. Defaults to 9.
             patch_size (int, optional): Patch size. Defaults to 7.
             n_classes (int, optional): Number of classes. Defaults to 1000.
             act_fn (nn.Module, optional): Activation function. Defaults to nn.GELU().
             stem (nn.Module, optional): You can pass a different first layer.
-            stem_ks (int, optional): If stem_ch not 0 - kernel size for adittional layer. Defaults to 1.
-            bn_1st (bool, optional): If True - BatchNorm befor activation function. Defaults to False.
-            pre_act (bool, optional): If True - activatin function befor convolution layer. Defaults to False.
+            in_chans (int, optional): Number of input channels. Defaults to 3.
+            bn_1st (bool, optional): If True - BatchNorm before activation function. Defaults to False.
+            pre_act (bool, optional): If True - activation function before convolution layer. Defaults to False.
             init_func (Callable, optional): External function for model initialization.
 
         """
         if pre_act:
             bn_1st = False
         if stem is None:
-            stem = ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
+            stem = ConvLayer(
+                in_chans,
+                dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                act_fn=act_fn,
+                bn_1st=bn_1st,
+            )
 
         super().__init__(
             stem,
-            *[nn.Sequential(
-                Residual(
-                    ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
-                              groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
-                ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
-                for i in range(depth)],
+            *[
+                nn.Sequential(
+                    Residual(
+                        ConvLayer(
+                            dim,
+                            dim,
+                            kernel_size,
+                            act_fn=act_fn,
+                            groups=dim,
+                            padding="same",
+                            bn_1st=bn_1st,
+                            pre_act=pre_act,
+                        )
+                    ),
+                    ConvLayer(
+                        dim,
+                        dim,
+                        kernel_size=1,
+                        act_fn=act_fn,
+                        bn_1st=bn_1st,
+                        pre_act=pre_act,
+                    ),
+                )
+                for _ in range(depth)
+            ],
             nn.AdaptiveAvgPool2d((1, 1)),
             nn.Flatten(),
-            nn.Linear(dim, n_classes))
+            nn.Linear(dim, n_classes)
+        )
         if init_func is not None:  # pragma: no cover
             init_func(self)
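
A short end-to-end sketch of the refactored class, exercising the new in_chans argument (the values are illustrative, not from the source):

    import torch

    # Grayscale input; pre_act moves each activation before its convolution.
    model = ConvMixer(dim=64, depth=2, n_classes=10, in_chans=1, pre_act=True)
    out = model(torch.randn(2, 1, 224, 224))
    assert out.shape == (2, 10)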