File tree 8 files changed +22
-22
lines changed
src/warmup_scheduler_pytorch
8 files changed +22
-22
lines changed Original file line number Diff line number Diff line change 8
8
9
9
## Description
10
10
11
- A Warmup Scheduler for Pytorch to achieve the warmup learning rate at the beginning of training.
11
+ A Warmup Scheduler for Pytorch to make the warmup learning rate change at the beginning of training.
12
12
13
13
## setup
14
14
15
+ Notice: need to install pytorch>=1.1.0 manually. \
16
+ The official website of pytorch is: https://pytorch.org/
17
+
18
+ Then install as follows:
19
+
15
20
```
16
21
pip install warmup_scheduler_pytorch
17
22
```
@@ -26,7 +31,7 @@ import torch
26
31
from torch.optim import SGD # example
27
32
from torch.optim.lr_scheduler import CosineAnnealingLR # example
28
33
29
- from warmup_scheduler_pytorch.warmup_module import WarmUpScheduler
34
+ from warmup_scheduler_pytorch import WarmUpScheduler
30
35
31
36
model = Model()
32
37
optimizer = SGD(model.parameters(), lr = 0.1 )
Original file line number Diff line number Diff line change @@ -45,7 +45,7 @@ def run():
45
45
epoch_lr [1 ].append (get_lr (optimizer ))
46
46
47
47
# output = model(...)
48
- # loss = loss_fn(output, ... )
48
+ # loss = loss_fn(output, label )
49
49
# loss.backward()
50
50
optimizer .step ()
51
51
optimizer .zero_grad ()
Original file line number Diff line number Diff line change 1
1
[build-system ]
2
- requires = [" setuptools>=42" ]
2
+ requires = [" setuptools>=42.0.0 " ]
3
3
build-backend = " setuptools.build_meta"
Original file line number Diff line number Diff line change @@ -13,7 +13,10 @@ classifiers =
13
13
Intended Audience :: Science/Research
14
14
License :: OSI Approved :: MIT License
15
15
Operating System :: OS Independent
16
- Programming Language :: Python :: 3
16
+ Programming Language :: Python :: 3.6
17
+ Programming Language :: Python :: 3.7
18
+ Programming Language :: Python :: 3.8
19
+ Programming Language :: Python :: 3.9
17
20
Topic :: Scientific/Engineering :: Artificial Intelligence
18
21
19
22
@@ -22,8 +25,8 @@ package_dir =
22
25
= src
23
26
packages = find:
24
27
python_requires = >=3.6
25
- install_requires =
26
- torch >= 1.7.1
28
+ # install_requires=
29
+ # torch >= 1.1.0
27
30
28
31
29
32
[options.packages.find]
Original file line number Diff line number Diff line change 1
1
from .warmup_module import WarmUpScheduler , VERSION
2
2
3
- __all__ = ['__version__' , 'WarmUpScheduler' ]
4
-
5
3
__version__ = VERSION
6
4
5
+ __all__ = ['__version__' , 'WarmUpScheduler' ]
Original file line number Diff line number Diff line change 4
4
"""
5
5
6
6
from torch .optim import Optimizer
7
- from torch .optim .lr_scheduler import _LRScheduler # ignore its error
7
+ from torch .optim .lr_scheduler import _LRScheduler
8
8
9
9
__all__ = ['VERSION' , 'WarmUpScheduler' ]
10
10
11
- VERSION = '0.1.0 '
11
+ VERSION = '0.1.1 '
12
12
13
13
14
14
class WarmUpScheduler (object ):
Original file line number Diff line number Diff line change 1
- from src .warmup_scheduler_pytorch import WarmUpScheduler , __version__
2
- from src .warmup_scheduler_pytorch .warmup_module import VERSION
3
-
4
-
5
- def test_version ():
6
- assert VERSION == __version__
7
-
8
-
9
1
def test_import ():
10
- assert isinstance (WarmUpScheduler , object )
2
+ from src .warmup_scheduler_pytorch import WarmUpScheduler , __version__
3
+ from src .warmup_scheduler_pytorch .warmup_module import VERSION
Original file line number Diff line number Diff line change @@ -77,8 +77,8 @@ def test_warmup_init(self):
77
77
pass
78
78
79
79
def test_warmup_state_dict (self ):
80
- sd = self .warmup_scheduler .state_dict ()
81
- self .warmup_scheduler .load_state_dict (sd )
80
+ state_dict = self .warmup_scheduler .state_dict ()
81
+ self .warmup_scheduler .load_state_dict (state_dict )
82
82
83
83
def test_warmup_get (self ):
84
84
self .warmup_scheduler .get_last_lr ()
You can’t perform that action at this time.
0 commit comments