Commit 349b81b

initial commit (0 parents)
File tree: 182 files changed, +7701 -0 lines changed
.gitignore

+147
```
# user added
*.local
.vscode/**

# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
**/.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


# USER DEFINED
zzz*
__old__
outputs/
wandb/
runs/
*.ipynb
debug.py
reproducibility.py
logs/*
data/**
default/**
debug/**
configs/wandb/_defaults.yaml
```

.gitmodules

+3
```ini
[submodule "site"]
	path = site
	url = https://github.com/nik-dim/sequel-site
```

LICENSE

+19
```text
Copyright (c) 2023 "NAME"

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

README.md

+63
# Sequel: A Continual Learning Library in PyTorch and JAX

Sequel provides a simple, easy-to-use framework for continual learning. The library is written in PyTorch and JAX and offers a simple interface for running experiments. It is still in development, and we are working on adding more algorithms and datasets.

- Documentation: https://nik-dim.github.io/sequel-site/
- Reproducibility Board: https://nik-dim.github.io/sequel-site/reproducibility/
- Weights&Biases: https://wandb.ai/nikdim/SequeL

## Installation

The library can be installed via pip:
```bash
pip install sequel-core
```

Alternatively, you can install the library from source:
```bash
git clone https://github.com/nik-dim/sequel.git
python3 -m build
```

or clone the repository and use the library directly. In that case you need to install the dependencies listed in `requirements.txt`. We recommend using a conda environment. The following commands create a conda environment with the required packages and activate it:
```bash
# create the conda environment
conda create -n sequel -y python=3.10 cuda cudatoolkit cuda-nvcc -c nvidia -c anaconda
conda activate sequel

# install all required packages
pip install -r requirements.txt

# Optional: depending on the machine, the next command might be needed to enable CUDA support for GPUs
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

## Run an experiment

For some examples, you can modify the `example_pytorch.py` and `example_jax.py` files, or run:
```bash
# example experiment in PyTorch
python example_pytorch.py

# ...and in JAX
python example_jax.py
```

Experiment configs are located in the `configs/examples/` directory. To run an experiment, simply do the following:

```bash
python main.py +experiment=EXPERIMENT_DIR/EXPERIMENT

# examples
python main.py +examples=ewc_rotatedmnist mode=pytorch  # or mode=jax
python main.py +examples=mcsgd_rotatedmnist mode=pytorch  # or mode=jax
```

To create your own experiment, follow the template of the experiments in `configs/examples/`: override the defaults to select, e.g., a different algorithm, and specify the training details. To run multiple experiments with different configs, use the `--multirun` flag of [Hydra](https://hydra.cc/docs). For instance:
```bash
python main.py --multirun +examples=ewc_rotatedmnist \
    mode=pytorch optimizer.lr=0.01,0.001 \
    benchmark.batch_size=128,256 \
    training.epochs_per_task=1  # online setting
```
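The dotted overrides above replace single leaves of the composed config. As a minimal sketch of that semantics (a hypothetical `apply_override` helper on plain dicts, not Hydra's actual implementation):

```python
# Simplified illustration of a Hydra-style dotted override such as
# `optimizer.lr=0.01`: walk the nested config and replace one leaf value.
defaults = {"optimizer": {"name": "sgd", "lr": 0.1}, "training": {"epochs_per_task": 1}}

def apply_override(cfg, dotted_key, value):
    """Set the leaf at dotted_key (e.g. 'optimizer.lr') to value, in place."""
    *path, leaf = dotted_key.split(".")
    node = cfg
    for p in path:
        node = node[p]
    node[leaf] = value
    return cfg

cfg = apply_override(defaults, "optimizer.lr", 0.01)
print(cfg["optimizer"]["lr"])  # → 0.01
```

With `--multirun`, Hydra applies one such override combination per run, sweeping over the comma-separated values.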

_main_jax.py

+57
```python
import logging

import hydra
from omegaconf import DictConfig, OmegaConf

from sequel.algos.jax import ALGOS
from sequel.backbones.jax import select_backbone, select_optimizer
from sequel.benchmarks import select_benchmark
from sequel.utils.callbacks.metrics.jax_metric_callback import StandardMetricCallback
from sequel.utils.callbacks.tqdm_callback import TqdmCallback
from sequel.utils.loggers.logging import install_logging
from sequel.utils.loggers.wandb_logger import WandbLogger
from sequel.utils.utils import set_seed


def without(d, key):
    new_d = d.copy()
    new_d.pop(key)
    return new_d


@hydra.main(config_path="configs", config_name="config", version_base="1.1")
def my_app(config: DictConfig) -> None:
    install_logging()
    logging.info("The experiment config is:\n" + OmegaConf.to_yaml(config))
    logger = WandbLogger(config)

    set_seed(config.seed)

    mc = StandardMetricCallback()
    tq = TqdmCallback()
    # initialize benchmark (e.g. SplitMNIST)
    benchmark = select_benchmark(config.benchmark)
    logging.info(benchmark)

    # initialize backbone model (e.g. a CNN, MLP)
    backbone = select_backbone(config)
    logging.info(backbone)

    optimizer = select_optimizer(config, backbone)

    algo = ALGOS[config.algo.name.lower()](
        **without(dict(config.algo), "name"),
        backbone=backbone,
        benchmark=benchmark,
        optimizer=optimizer,
        callbacks=[mc, tq],
        loggers=[logger],
    )
    logging.info(algo)

    # start the learning process!
    algo.fit(epochs=config.training.epochs_per_task)


if __name__ == "__main__":
    my_app()
```
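The `without` helper above is used to pass everything except the `name` key of the algo config on to the algorithm's constructor. A standalone copy, exercised with a dict mirroring `configs/algo/ewc.yaml`:

```python
# Standalone copy of the `without` helper from _main_jax.py: returns a
# shallow copy of a dict with one key removed, leaving the original intact.
def without(d, key):
    new_d = d.copy()
    new_d.pop(key)
    return new_d

algo_cfg = {"name": "ewc", "ewc_lambda": 1.0}  # mirrors configs/algo/ewc.yaml
kwargs = without(algo_cfg, "name")
print(kwargs)              # → {'ewc_lambda': 1.0}
print("name" in algo_cfg)  # → True (original dict is unchanged)
```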

_main_pytorch.py

+59
```python
import logging

import hydra
from omegaconf import DictConfig, OmegaConf
from sequel.utils.loggers.logging import install_logging
from sequel.utils.callbacks.metrics.pytorch_metric_callback import StandardMetricCallback
from sequel.benchmarks import select_benchmark

from sequel.backbones.pytorch import select_backbone, select_optimizer
from sequel.utils.callbacks.tqdm_callback import TqdmCallback
from sequel.utils.loggers.wandb_logger import WandbLogger

from sequel.algos.pytorch import ALGOS
from sequel.utils.utils import set_seed


def without(d, key):
    new_d = d.copy()
    new_d.pop(key)
    return new_d


@hydra.main(config_path="configs", config_name="config", version_base="1.1")
def my_app(config: DictConfig) -> None:
    install_logging()
    logging.info("The experiment config is:\n" + OmegaConf.to_yaml(config))
    logger = WandbLogger(config)

    set_seed(config.seed)

    mc = StandardMetricCallback()
    tq = TqdmCallback()

    # initialize benchmark (e.g. SplitMNIST)
    benchmark = select_benchmark(config.benchmark)
    logging.info(benchmark)

    # initialize backbone model (e.g. a CNN, MLP)
    backbone = select_backbone(config)
    logging.info(backbone)

    optimizer = select_optimizer(config, backbone)

    algo = ALGOS[config.algo.name.lower()](
        **without(dict(config.algo), "name"),
        backbone=backbone,
        benchmark=benchmark,
        optimizer=optimizer,
        callbacks=[mc, tq],
        loggers=[logger],
    )
    logging.info(algo)

    # start the learning process!
    algo.fit(epochs=config.training.epochs_per_task)


if __name__ == "__main__":
    my_app()
```

configs/algo/agem.yaml

+4
```yaml
name: agem
per_task_memory_samples: 250
memory_batch_size: 256
memory_group_by: task
```

configs/algo/der++.yaml

+5
```yaml
name: der
alpha: 4
memory_size: 1000

beta: 4
```

configs/algo/der.yaml

+3
```yaml
name: der
alpha: 4
memory_size: 1000
```

configs/algo/er.yaml

+4
```yaml
name: er
per_task_memory_samples: 100
memory_batch_size: 100
memory_group_by: class
```

configs/algo/ewc.yaml

+2
```yaml
name: ewc
ewc_lambda: 1.0
```

configs/algo/icarl.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
name: icarl
2+
memory_size: 2000

configs/algo/joint.yaml

+1
```yaml
name: joint
```

configs/algo/kcl.yaml

+7
```yaml
name: kcl
core_size: 20
d_rn_f: 2048
kernel_type: rff
lmd: 0.1
lr_decay: 0.95
tau: 0.01
```
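Each algo YAML above supplies a `name` used to select the algorithm class from the `ALGOS` registry in the entry-point scripts, with the remaining keys passed as constructor keyword arguments. A minimal sketch of that dispatch, using a hypothetical stand-in `EWC` class rather than sequel's real implementation:

```python
# Hypothetical stand-ins for sequel's ALGOS registry and an algorithm class;
# illustrates how `name` selects the class and the other keys become kwargs.
class EWC:
    def __init__(self, ewc_lambda, **kwargs):
        self.ewc_lambda = ewc_lambda

ALGOS = {"ewc": EWC}

algo_cfg = {"name": "ewc", "ewc_lambda": 1.0}  # contents of configs/algo/ewc.yaml
kwargs = {k: v for k, v in algo_cfg.items() if k != "name"}
algo = ALGOS[algo_cfg["name"].lower()](**kwargs)
print(type(algo).__name__, algo.ewc_lambda)  # → EWC 1.0
```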
