
Commit 0823635

Merge pull request #92 from ayasyrev/dev_0.3.3
0.3.3
2 parents 57f9fc2 + 6928297 commit 0823635

21 files changed: +1288 -403 lines changed

docs/overrides/partials/copyright.html

-3 lines

@@ -1,6 +1,3 @@
-{#-
-    This file was automatically generated - do not edit
--#}
 <div class="md-copyright">
   {% if config.copyright %}
     <div class="md-copyright__highlight">

mkdocs.yaml

+1 line

@@ -24,3 +24,4 @@ extra:
   analytics:
     provider: google
     property: G-0F3FK713C2
+copyright: Copyright &copy; 2020-2023 Andrei Yasyrev.

src/model_constructor/__init__.py

-1 line

@@ -1,7 +1,6 @@
 from model_constructor.convmixer import ConvMixer  # noqa F401
 from model_constructor.model_constructor import (
     ModelConstructor,
-    ResBlock,
     ModelCfg,
 )  # noqa F401
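
Note: with ResBlock no longer re-exported at package level, the top-level surface is ModelConstructor, ModelCfg, and ConvMixer. A minimal usage sketch (not part of the commit; assumes the library's documented pattern of calling a constructor instance to build the model):

import torch
from model_constructor import ModelCfg, ModelConstructor

mc = ModelConstructor()  # constructor with default (resnet-style) config
model = mc()             # calling the instance builds the nn.Sequential model
print(type(model))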

src/model_constructor/base_constructor.py

+80 -4 lines

@@ -1,10 +1,31 @@
-import torch.nn as nn
+"""First version of constructor.
+"""
+# Used in examples.
+# first implementation of xresnet - inspired by fastai version.
 from collections import OrderedDict
-from .layers import ConvLayer, Noop, Flatten
+from functools import partial
 
+import torch.nn as nn
 
-__all__ = ['act_fn', 'Stem', 'DownsampleBlock', 'BasicBlock', 'Bottleneck', 'BasicLayer', 'Body', 'Head', 'init_model',
-           'Net']
+from .layers import ConvLayer, Flatten, Noop
+
+__all__ = [
+    "act_fn",
+    "Stem",
+    "DownsampleBlock",
+    "BasicBlock",
+    "Bottleneck",
+    "BasicLayer",
+    "Body",
+    "Head",
+    "init_model",
+    "Net",
+    "DownsampleLayer",
+    "XResBlock",
+    "xresnet18",
+    "xresnet34",
+    "xresnet50",
+]
 
 
 act_fn = nn.ReLU(inplace=True)
@@ -162,3 +183,58 @@ def __init__(self, stem=Stem,
             ('head', head(body_out * expansion, num_classes, **kwargs))
         ]))
         self.init_model(self)
+
+
+# xresnet from fastai
+
+
+class DownsampleLayer(nn.Sequential):
+    """Downsample layer for Xresnet Resblock"""
+
+    def __init__(self, conv_layer, ni, nf, stride, act,
+                 pool=nn.AvgPool2d(2, ceil_mode=True), pool_1st=True,
+                 **kwargs):
+        layers = [] if stride == 1 else [('pool', pool)]
+        layers += [] if ni == nf else [('idconv', conv_layer(ni, nf, 1, act=act, **kwargs))]
+        if not pool_1st:
+            layers.reverse()
+        super().__init__(OrderedDict(layers))
+
+
+class XResBlock(nn.Module):
+    '''XResnet block'''
+
+    def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
+                 conv_layer=ConvLayer, act_fn=act_fn, **kwargs):
+        super().__init__()
+        nf, ni = nh * expansion, ni * expansion
+        layers = [('conv_0', conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
+                  ('conv_1', conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
+                  ] if expansion == 1 else [
+                  ('conv_0', conv_layer(ni, nh, 1, act_fn=act_fn, **kwargs)),
+                  ('conv_1', conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),
+                  ('conv_2', conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, act_fn=act_fn, **kwargs))
+                  ]
+        self.convs = nn.Sequential(OrderedDict(layers))
+        self.identity = DownsampleLayer(conv_layer, ni, nf, stride,
+                                        act=False, act_fn=act_fn, **kwargs) if ni != nf or stride == 2 else Noop()
+        self.merge = Noop()
+        self.act_fn = act_fn
+
+    def forward(self, x):
+        return self.act_fn(self.merge(self.convs(x) + self.identity(x)))
+
+
+def xresnet18(**kwargs):
+    """Constructs xresnet18 model."""
+    return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[2, 2, 2, 2], expansion=1, **kwargs)
+
+
+def xresnet34(**kwargs):
+    """Constructs xresnet34 model."""
+    return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=1, **kwargs)
+
+
+def xresnet50(**kwargs):
+    """Constructs xresnet50 model."""
+    return Net(stem_sizes=[32, 32], block=XResBlock, blocks=[3, 4, 6, 3], expansion=4, **kwargs)
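
The new xresnet constructors are plain functions over Net. A quick smoke-test sketch (not part of the commit; assumes the default head outputs 1000 classes and uses adaptive pooling, so the input size is flexible):

import torch
from model_constructor.base_constructor import xresnet18

model = xresnet18()              # Net with XResBlock, blocks=[2, 2, 2, 2]
x = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
out = model(x)
print(out.shape)                 # expected: torch.Size([2, 1000])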

src/model_constructor/convmixer.py

+108 -42 lines

@@ -3,37 +3,49 @@
 # Adopted from https://github.com/tmp-iclr/convmixer
 # Home for convmixer: https://github.com/locuslab/convmixer
 from collections import OrderedDict
-from typing import Callable
+from typing import Callable, List, Optional, Union
+
 import torch.nn as nn
+from torch import TensorType
 
 
 class Residual(nn.Module):
-    def __init__(self, fn):
+    def __init__(self, fn: Callable[[TensorType], TensorType]):
         super().__init__()
         self.fn = fn
 
-    def forward(self, x):
+    def forward(self, x: TensorType) -> TensorType:
         return self.fn(x) + x
 
 
 # As original version, act_fn as argument.
-def ConvMixerOriginal(dim, depth,
-                      kernel_size=9, patch_size=7, n_classes=1000,
-                      act_fn=nn.GELU()):
+def ConvMixerOriginal(
+    dim: int,
+    depth: int,
+    kernel_size: int = 9,
+    patch_size: int = 7,
+    n_classes: int = 1000,
+    act_fn: nn.Module = nn.GELU(),
+):
     return nn.Sequential(
         nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
         act_fn,
         nn.BatchNorm2d(dim),
-        *[nn.Sequential(
-            Residual(nn.Sequential(
-                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+        *[
+            nn.Sequential(
+                Residual(
+                    nn.Sequential(
+                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+                        act_fn,
+                        nn.BatchNorm2d(dim),
+                    )
+                ),
+                nn.Conv2d(dim, dim, kernel_size=1),
                 act_fn,
-                nn.BatchNorm2d(dim)
-            )),
-            nn.Conv2d(dim, dim, kernel_size=1),
-            act_fn,
-            nn.BatchNorm2d(dim)
-        ) for i in range(depth)],
+                nn.BatchNorm2d(dim),
+            )
+            for _i in range(depth)
+        ],
         nn.AdaptiveAvgPool2d((1, 1)),
         nn.Flatten(),
         nn.Linear(dim, n_classes)
@@ -43,15 +55,35 @@ def ConvMixerOriginal(dim, depth,
 class ConvLayer(nn.Sequential):
     """Basic conv layers block"""
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 act_fn=nn.GELU(), padding=0, groups=1,
-                 bn_1st=False, pre_act=False):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, tuple[int, int]],
+        stride: int = 1,
+        act_fn: nn.Module = nn.GELU(),
+        padding: Union[int, str] = 0,
+        groups: int = 1,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+    ):
 
-        conv_layer = [('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
-                                         padding=padding, groups=groups))]
-        act_bn = [
-            ('act_fn', act_fn),
-            ('bn', nn.BatchNorm2d(out_channels))
+        conv_layer: List[tuple[str, nn.Module]] = [
+            (
+                "conv",
+                nn.Conv2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    groups=groups,
+                ),
+            )
+        ]
+        act_bn: List[tuple[str, nn.Module]] = [
+            ("act_fn", act_fn),
+            ("bn", nn.BatchNorm2d(out_channels)),
         ]
         if bn_1st:
             act_bn.reverse()
@@ -64,45 +96,79 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
 
 
 class ConvMixer(nn.Sequential):
-
-    def __init__(self, dim: int, depth: int,
-                 kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
-                 act_fn: nn.Module = nn.GELU(),
-                 stem: nn.Module = None,
-                 bn_1st: bool = False, pre_act: bool = False,
-                 init_func: Callable = None):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        kernel_size: int = 9,
+        patch_size: int = 7,
+        n_classes: int = 1000,
+        act_fn: nn.Module = nn.GELU(),
+        stem: Optional[nn.Module] = None,
+        in_chans: int = 3,
+        bn_1st: bool = False,
+        pre_act: bool = False,
+        init_func: Optional[Callable[[nn.Module], None]] = None,
+    ):
         """ConvMixer constructor.
         Adopted from https://github.com/tmp-iclr/convmixer
 
         Args:
-            dim (int): Dimention of model.
+            dim (int): Dimension of model.
             depth (int): Depth of model.
             kernel_size (int, optional): Kernel size. Defaults to 9.
             patch_size (int, optional): Patch size. Defaults to 7.
             n_classes (int, optional): Number of classes. Defaults to 1000.
             act_fn (nn.Module, optional): Activation function. Defaults to nn.GELU().
             stem (nn.Module, optional): You can path different first layer..
-            stem_ks (int, optional): If stem_ch not 0 - kernel size for adittional layer. Defaults to 1.
-            bn_1st (bool, optional): If True - BatchNorm befor activation function. Defaults to False.
-            pre_act (bool, optional): If True - activatin function befor convolution layer. Defaults to False.
+            stem_ks (int, optional): If stem_ch not 0 - kernel size for additional layer. Defaults to 1.
+            bn_1st (bool, optional): If True - BatchNorm before activation function. Defaults to False.
+            pre_act (bool, optional): If True - activation function before convolution layer. Defaults to False.
             init_func (Callable, optional): External function for init model.
 
         """
         if pre_act:
             bn_1st = False
         if stem is None:
-            stem = ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)
+            stem = ConvLayer(
+                in_chans,
+                dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                act_fn=act_fn,
+                bn_1st=bn_1st,
+            )
 
         super().__init__(
             stem,
-            *[nn.Sequential(
-                Residual(
-                    ConvLayer(dim, dim, kernel_size, act_fn=act_fn,
-                              groups=dim, padding="same", bn_1st=bn_1st, pre_act=pre_act)),
-                ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
-                for i in range(depth)],
+            *[
+                nn.Sequential(
+                    Residual(
+                        ConvLayer(
+                            dim,
+                            dim,
+                            kernel_size,
+                            act_fn=act_fn,
+                            groups=dim,
+                            padding="same",
+                            bn_1st=bn_1st,
+                            pre_act=pre_act,
+                        )
+                    ),
+                    ConvLayer(
+                        dim,
+                        dim,
+                        kernel_size=1,
+                        act_fn=act_fn,
+                        bn_1st=bn_1st,
+                        pre_act=pre_act,
+                    ),
+                )
+                for _ in range(depth)
+            ],
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
-            nn.Linear(dim, n_classes))
+            nn.Linear(dim, n_classes)
+        )
         if init_func is not None:  # pragma: no cover
             init_func(self)
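
The notable API change here is the new in_chans argument, which replaces the hard-coded 3 input channels in the default stem. A minimal sketch of using it for single-channel (e.g. grayscale) input, not part of the commit:

import torch
from model_constructor.convmixer import ConvMixer

# in_chans is new in this release; only the default stem uses it.
model = ConvMixer(dim=256, depth=8, in_chans=1)
x = torch.randn(2, 1, 224, 224)  # two grayscale images
print(model(x).shape)            # torch.Size([2, 1000]) with default n_classes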

src/model_constructor/helpers.py

+9 lines

@@ -0,0 +1,9 @@
+from collections import OrderedDict
+from typing import Iterable
+
+from torch import nn
+
+
+def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
+    """return nn.Sequential from OrderedDict from list of tuples"""
+    return nn.Sequential(OrderedDict(list_of_tuples))
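
nn_seq is a small convenience wrapper: named layers go in as (name, module) tuples and come back as an nn.Sequential whose children are addressable by name. A short usage example:

from torch import nn
from model_constructor.helpers import nn_seq

block = nn_seq([
    ("conv", nn.Conv2d(3, 16, 3, padding=1)),
    ("bn", nn.BatchNorm2d(16)),
    ("act", nn.ReLU(inplace=True)),
])
print(block.bn)  # named submodules are accessible as attributes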
