Skip to content

Commit 32d49d9

Browse files
authored
Merge pull request #122 from ayasyrev/dev
4.2
2 parents c9f17d5 + d71e24f commit 32d49d9

17 files changed

+44
-72
lines changed

noxfile_cov.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import nox
22

33

4-
@nox.session(python=["3.10"])
4+
@nox.session(python=["3.11"])
55
def cov_tests(session: nox.Session) -> None:
66
args = session.posargs or ["--cov"]
77
session.install(".", "pytest", "pytest-cov", "coverage[toml]")

requirements_test.txt

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,2 @@
11
pytest
2-
pytest-cov
3-
coverage[toml]
4-
flake8
5-
nox
2+
pytest-cov

requirements_test_extra.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
coverage[toml]
2+
black
3+
flake8
4+
nox
5+
isort

setup.cfg

+5-2
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@ long_description_content_type = text/markdown
1010
url = https://github.com/ayasyrev/model_constructor
1111
license = apache2
1212
classifiers =
13-
Programming Language :: Python :: 3
13+
Programming Language :: Python :: 3.8
14+
Programming Language :: Python :: 3.9
15+
Programming Language :: Python :: 3.10
16+
Programming Language :: Python :: 3.11
1417
License :: OSI Approved :: Apache Software License
1518
Operating System :: OS Independent
1619

1720
[options]
1821
package_dir =
1922
= src
2023
packages = find:
21-
python_requires = >=3.7
24+
python_requires = >=3.8, <3.12
2225

2326
[options.packages.find]
2427
where = src

setup_.py

-50
This file was deleted.

src/model_constructor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .convmixer import ConvMixer
2-
from .model_constructor import ModelConstructor, ModelCfg
2+
from .model_constructor import ModelCfg, ModelConstructor
33
from .version import __version__

src/model_constructor/activations.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# forked from https://github.com/rwightman/pytorch-image-models/timm/models/layers/activations.py
22
import torch
33
from torch import nn as nn
4-
from torch.nn import functional as F
54
from torch.nn import Mish
6-
5+
from torch.nn import functional as F
76

87
__all__ = [
98
"mish",

src/model_constructor/helpers.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,18 @@ def print_set_fields(self) -> None:
140140
else:
141141
print("Nothing changed")
142142

143-
def print_changed_fields(self, show_default: bool = False, separator: str = " | ") -> None:
143+
def print_changed_fields(
144+
self, show_default: bool = False, separator: str = " | "
145+
) -> None:
144146
"""Print fields changed at init."""
145147
if self.changed_fields:
146148
default_value = ""
147149
print("Changed fields:")
148150
for field in self.changed_fields:
149151
if show_default:
150-
default_value = f"{separator}{self._get_str(self.model_fields[field].default)}"
152+
default_value = (
153+
f"{separator}{self._get_str(self.model_fields[field].default)}"
154+
)
151155
print(f"{field}: {self._get_str_value(field)}{default_value}")
152156
else:
153157
print("Nothing changed")

src/model_constructor/model_constructor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, Dict, List, Optional, Union, Type
3+
from typing import Any, Callable, Dict, List, Optional, Type, Union
44

55
from pydantic import field_validator
66
from pydantic_core.core_schema import FieldValidationInfo
@@ -32,7 +32,7 @@
3232
}
3333

3434

35-
nnModule = Union[Type[nn.Module], Callable[[], nn.Module]]
35+
nnModule = Union[Type[nn.Module], Callable[[Any], nn.Module]]
3636

3737

3838
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):

src/model_constructor/mxresnet.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Type
2+
23
from torch import nn
34

45
from .xresnet import XResNet

src/model_constructor/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.1"
1+
__version__ = "0.4.2_dev"

src/model_constructor/xresnet.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55

66
from .blocks import BottleneckBlock
7-
from .helpers import ListStrMod, nn_seq, ModSeq
7+
from .helpers import ListStrMod, ModSeq, nn_seq
88
from .model_constructor import ModelCfg, ModelConstructor
99

1010
__all__ = [
@@ -50,6 +50,11 @@ class XResNet34(XResNet):
5050
layers: List[int] = [3, 4, 6, 3]
5151

5252

53-
class XResNet50(XResNet34):
53+
class XResNet26(XResNet):
5454
block: Type[nn.Module] = BottleneckBlock
5555
block_sizes: List[int] = [256, 512, 1024, 2048]
56+
expansion: int = 4
57+
58+
59+
class XResNet50(XResNet26):
60+
layers: List[int] = [3, 4, 6, 3]

src/model_constructor/yaresnet.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
from model_constructor.helpers import ModSeq, nn_seq
1111

1212
from .layers import ConvBnAct, get_act
13-
from .model_constructor import ListStrMod, ModelConstructor, ModelCfg
13+
from .model_constructor import ListStrMod, ModelCfg, ModelConstructor
1414
from .xresnet import xresnet_stem
1515

16-
1716
__all__ = [
1817
"YaBasicBlock",
1918
"YaBottleneckBlock",
@@ -216,6 +215,11 @@ class YaResNet34(YaResNet):
216215
layers: List[int] = [3, 4, 6, 3]
217216

218217

219-
class YaResNet50(YaResNet34):
218+
class YaResNet26(YaResNet):
220219
block: Type[nn.Module] = YaBottleneckBlock
221220
block_sizes: List[int] = [256, 512, 1024, 2048]
221+
expansion: int = 4
222+
223+
224+
class YaResNet50(YaResNet26):
225+
layers: List[int] = [3, 4, 6, 3]

tests/test_blocks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import torch
55
from torch import nn
66

7-
from model_constructor.layers import SEModule, SimpleSelfAttention
87
from model_constructor.blocks import BasicBlock, BottleneckBlock
8+
from model_constructor.layers import SEModule, SimpleSelfAttention
99
from model_constructor.yaresnet import YaBasicBlock, YaBottleneckBlock
1010

1111
from .parameters import ids_fn

tests/test_helpers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def test_cfg_repr_print(capsys: CaptureFixture[str]):
7070
cfg.print_set_fields()
7171
out = capsys.readouterr().out
7272
assert out == "Nothing changed\n"
73-
assert "name" in cfg.model_fields_set # pylint: disable=E1135:unsupported-membership-test
73+
assert (
74+
"name" in cfg.model_fields_set
75+
) # pylint: disable=E1135:unsupported-membership-test
7476
cfg = Cfg2(int_value=0)
7577
cfg.print_set_fields()
7678
out = capsys.readouterr().out

tests/test_models.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Type
2+
23
import pytest
34
import torch
45
from torch import nn

tests/test_models_universal_blocks.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Type
2+
23
import pytest
34
import torch
45
from torch import nn

0 commit comments

Comments
 (0)