Commit 01c2e0e

Add files
1 parent b485ad1 commit 01c2e0e

File tree

12 files changed

+2221
-2
lines changed

README.md

100644 → 100755
Lines changed: 255 additions & 2 deletions
@@ -1,2 +1,255 @@
1-
# MambaRoll
2-
Official implementation of MambaRoll: A Physics-Driven Autoregressive State Space Model for Medical Image Reconstruction
1+
<hr>
2+
<h1 align="center">
3+
MambaRoll <br>
4+
<sub>Physics-Driven Autoregressive State Space Models for Medical Image Reconstruction</sub>
5+
</h1>
6+
7+
<div align="center">
8+
<a href="https://bilalkabas.github.io/" target="_blank">Bilal&nbsp;Kabas</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp;
9+
<a href="https://github.com/fuat-arslan" target="_blank">Fuat&nbsp;Arslan</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp;
10+
<a href="https://github.com/Valiyeh" target="_blank">Valiyeh&nbsp;A. Nezhad</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp;
11+
<a href="https://scholar.google.com/citations?hl=en&user=_SujLxcAAAAJ" target="_blank">Saban&nbsp;Ozturk</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp;
12+
<a href="https://kilyos.ee.bilkent.edu.tr/~saritas/" target="_blank">Emine U.&nbsp;Saritas</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp;
13+
<a href="https://kilyos.ee.bilkent.edu.tr/~cukur/" target="_blank">Tolga&nbsp;Çukur</a><sup>1,2</sup> &ensp;
14+
15+
<span></span>
16+
17+
<sup>1</sup>Bilkent University &emsp; <sup>2</sup>UMRAM <br>
18+
</div>
19+
<hr>
20+
21+
<h3 align="center">[<a href="https://arxiv.org/abs/2412.09331">arXiv</a>]</h3>
22+
23+
Official PyTorch implementation of **MambaRoll**, a novel physics-driven autoregressive state space model for enhanced fidelity in medical image reconstruction. In each cascade of an unrolled architecture, MambaRoll employs an autoregressive framework based on physics-driven state space modules (PSSMs). PSSMs efficiently aggregate contextual features at a given spatial scale while maintaining fidelity to acquired data, and autoregressive prediction of next-scale feature maps from earlier spatial scales enhances the capture of multi-scale contextual features.
24+
25+
<p align="center">
26+
<img src="figures/architecture.png" alt="architecture">
27+
</p>
28+
29+
## ⚙️ Installation
30+
31+
This repository has been developed and tested with `CUDA 12.2` and `Python 3.12`. The commands below create a conda environment with the required packages. Make sure conda is installed first.
32+
33+
```
34+
conda env create --file requirements.yaml
35+
conda activate mambaroll
36+
```
37+
38+
<details>
39+
<summary>[Optional] Setting Up Faster and Memory-efficient Radon Transform</summary><br>
40+
41+
We use a faster (over 100x) and more memory-efficient (~4.5x) implementation of the Radon transform ([torch-radon](https://github.com/matteo-ronchetti/torch-radon)). To install it, run the commands below within the `mambaroll` conda environment.
42+
43+
```
44+
git clone https://github.com/matteo-ronchetti/torch-radon.git
45+
cd torch-radon
46+
python setup.py install
47+
```
48+
49+
</details>
50+
51+
## 🗂️ Prepare dataset
52+
53+
MambaRoll supports reconstruction for the MRI and CT modalities, so `datasets.py` provides two dataset classes: (1) `MRIDataset` and (2) `CTDataset`.
54+
55+
### 1. MRI dataset folder structure
56+
57+
The MRI dataset has a subfolder for each undersampling rate, e.g. `us4x`, `us8x`. Within each split, there is a separate `.npz` file for each contrast.
58+
59+
<details>
60+
<summary>Details for npz files</summary><br>
61+
62+
A `<contrast>.npz` file has the following keys:
63+
64+
| Variable key | Description | Shape |
65+
|-----------------|-------------------------------------------|-------------------------------------------|
66+
| `image_fs` | Coil-combined fully-sampled MR image. | n_slices x 1 x height x width |
67+
| `image_us` | Multi-coil undersampled MR image. | n_slices x n_coils x height x width |
68+
| `us_masks` | K-space undersampling masks. | n_slices x 1 x height x width |
69+
| `coilmaps` | Coil sensitivity maps. | n_slices x n_coils x height x width |
70+
| `subject_ids` | Corresponding subject ID for each slice. | n_slices |
71+
| `us_factor` | Undersampling factor. | (Single integer value) |
72+
73+
</details>
74+
75+
```
76+
fastMRI/
77+
├── us4x/
78+
│ ├── train/
79+
│ │ ├── T1.npz
80+
│ │ ├── T2.npz
81+
│ │ └── FLAIR.npz
82+
│ ├── test/
83+
│ │ ├── T1.npz
84+
│ │ ├── T2.npz
85+
│ │ └── FLAIR.npz
86+
│ └── val/
87+
│ ├── T1.npz
88+
│ ├── T2.npz
89+
│ └── FLAIR.npz
90+
├── us8x/
91+
│ ├── train/...
92+
│ ├── test/...
93+
│ └── val/...
94+
├── ...
95+
```
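The layout and key table above can be checked before training. The following is a minimal sketch, not part of the repository: the helper names `mri_npz_path`, `check_mri_shapes`, and `load_shapes` are hypothetical, and the check assumes that the slice, coil, height, and width dimensions of `image_us` define the expected shapes of the other arrays.

```python
from pathlib import Path

# Keys listed in the npz table above.
REQUIRED_KEYS = ["image_fs", "image_us", "us_masks", "coilmaps", "subject_ids", "us_factor"]


def mri_npz_path(dataset_dir, us_factor, split, contrast):
    """Build <dataset_dir>/us<factor>x/<split>/<contrast>.npz as in the tree above."""
    return Path(dataset_dir) / f"us{us_factor}x" / split / f"{contrast}.npz"


def check_mri_shapes(shapes):
    """Return a list of problems in a {key: shape} mapping (empty if consistent)."""
    missing = [key for key in REQUIRED_KEYS if key not in shapes]
    if missing:
        return [f"missing keys: {missing}"]
    image_us = tuple(shapes["image_us"])
    if len(image_us) != 4:
        return [f"image_us should be 4-D, got shape {image_us}"]
    # Assumption: image_us fixes the reference dimensions for all other arrays.
    n_slices, n_coils, height, width = image_us
    expected = {
        "image_fs": (n_slices, 1, height, width),
        "us_masks": (n_slices, 1, height, width),
        "coilmaps": (n_slices, n_coils, height, width),
        "subject_ids": (n_slices,),
    }
    problems = []
    for key, want in expected.items():
        got = tuple(shapes[key])
        if got != want:
            problems.append(f"{key}: expected {want}, got {got}")
    return problems


def load_shapes(path):
    """Read array shapes from an .npz file (NumPy is imported lazily)."""
    import numpy as np

    with np.load(path) as data:
        return {key: data[key].shape for key in data.files}
```

A typical call would be `check_mri_shapes(load_shapes(mri_npz_path("../datasets/fastMRI", 8, "train", "T1")))`, which returns an empty list when the file matches the table.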
96+
97+
98+
99+
### 2. CT dataset folder structure
100+
101+
Each split in the CT dataset contains images at different undersampling rates.
102+
103+
<details>
104+
<summary>Details for npz files</summary><br>
105+
106+
Each `image_fs.npz` file holds the fully-sampled data under the following key:
107+
108+
| Variable key | Description | Shape |
109+
|-----------------|-------------------------------------------|-------------------------------------------|
110+
| `image_fs` | Fully-sampled CT image. | n_slices x 1 x height x width |
111+
112+
A `us<us_factor>x.npz` file has the following keys:
113+
114+
| Variable key | Description | Shape |
115+
|-----------------------|----------------------------------------------------------------------------------------------------------------|---------------------------------------------------|
116+
| `image_us` | Undersampled CT image. | n_slices x 1 x height x width |
117+
| `sinogram_us` | Corresponding sinograms for undersampled CTs. | n_slices x 1 x detector_positions x n_projections |
118+
| `projection_angles`   | Projection angles at which the Radon transform is performed on fully-sampled images to obtain undersampled ones. | n_slices x n_projections                          |
119+
| `subject_ids` | Corresponding subject ID for each slice. | n_slices |
120+
| `us_factor` | Undersampling factor. | (Single integer value) |
121+
122+
</details>
123+
124+
```
125+
lodopab-ct/
126+
├── train/
127+
│ ├── image_fs.npz
128+
│ ├── us4x.npz
129+
│ └── us6x.npz
130+
├── test/
131+
│ ├── image_fs.npz
132+
│ ├── us4x.npz
133+
│ └── us6x.npz
134+
└── val/
135+
├── image_fs.npz
136+
├── us4x.npz
137+
└── us6x.npz
138+
```
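As with the MRI data, the CT files can be checked against the tables above. This is a sketch under stated assumptions rather than repository code: `ct_npz_paths` and `check_ct_shapes` are hypothetical helpers, and the check assumes that `image_fs.npz` and `us<us_factor>x.npz` in the same split contain the same slices in the same order.

```python
from pathlib import Path


def ct_npz_paths(dataset_dir, split, us_factor):
    """Return (fully-sampled file, undersampled file) for one split, per the tree above."""
    split_dir = Path(dataset_dir) / split
    return split_dir / "image_fs.npz", split_dir / f"us{us_factor}x.npz"


def check_ct_shapes(fs_shapes, us_shapes):
    """Return a list of problems given {key: shape} mappings for both files."""
    problems = []
    if "image_fs" not in fs_shapes:
        problems.append("missing key: image_fs")
    for key in ("image_us", "sinogram_us", "projection_angles", "subject_ids"):
        if key not in us_shapes:
            problems.append(f"missing key: {key}")
    if problems:
        return problems

    image_us = tuple(us_shapes["image_us"])
    sinogram = tuple(us_shapes["sinogram_us"])
    if len(image_us) != 4 or len(sinogram) != 4:
        return ["image_us and sinogram_us should both be 4-D"]

    n_slices, _, height, width = image_us
    detector_positions, n_projections = sinogram[2], sinogram[3]
    expected = [
        # Assumption: both files hold the same slices, so n_slices must agree.
        ("image_fs", fs_shapes["image_fs"], (n_slices, 1, height, width)),
        ("image_us", image_us, (n_slices, 1, height, width)),
        ("sinogram_us", sinogram, (n_slices, 1, detector_positions, n_projections)),
        ("projection_angles", us_shapes["projection_angles"], (n_slices, n_projections)),
        ("subject_ids", us_shapes["subject_ids"], (n_slices,)),
    ]
    for key, got, want in expected:
        if tuple(got) != want:
            problems.append(f"{key}: expected {want}, got {tuple(got)}")
    return problems
```

The shape mappings can be produced with `numpy.load` by reading `.shape` for each key in the archive; the array sizes themselves depend on how the dataset was preprocessed.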
139+
140+
141+
142+
## 🏃 Training
143+
144+
Run the following command to start or resume training. Model checkpoints are saved under the `logs/$EXP_NAME/MambaRoll/checkpoints` directory, and sample validation images are saved under `logs/$EXP_NAME/MambaRoll/val_samples`. The script supports both single- and multi-GPU training and runs on a single GPU by default. To enable multi-GPU training, set the `--trainer.devices` argument to the list of devices, e.g. `0,1,2,3`. Be aware that multi-GPU training may lead to convergence issues, so using multiple GPUs is recommended only for inference/testing.
145+
146+
```
147+
python main.py fit \
148+
--config $CONFIG_PATH \
149+
--trainer.logger.name $EXP_NAME \
150+
--model.mode $MODE \
151+
--data.dataset_dir $DATA_DIR \
152+
--data.contrast $CONTRAST \
153+
--data.us_factor $US_FACTOR \
154+
--data.train_batch_size $BS_TRAIN \
155+
--data.val_batch_size $BS_VAL \
156+
[--trainer.max_epoch $N_EPOCHS] \
157+
[--ckpt_path $CKPT_PATH] \
158+
[--trainer.devices $DEVICES]
159+
160+
```
161+
162+
<details>
163+
<summary>Example Commands</summary>
164+
165+
MRI reconstruction using fastMRI dataset:
166+
167+
```
168+
python main.py fit \
169+
--config configs/config_fastmri.yaml \
170+
--trainer.logger.name fastmri_t1_us8x \
171+
--data.dataset_dir ../datasets/fastMRI \
172+
--data.contrast T1 \
173+
--data.us_factor 8 \
174+
--data.train_batch_size 1 \
175+
--data.val_batch_size 16 \
176+
--trainer.devices [0]
177+
```
178+
179+
CT reconstruction using [LoDoPaB-CT](https://zenodo.org/records/3384092) dataset:
180+
181+
```
182+
python main.py fit \
183+
--config configs/config_ct.yaml \
184+
--trainer.logger.name ct_us4x \
185+
--data.dataset_dir ../datasets/lodopab-ct/ \
186+
--data.us_factor 4 \
187+
--data.train_batch_size 1 \
188+
--data.val_batch_size 16 \
189+
--trainer.devices [0]
190+
```
191+
</details>
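
When resuming training, the checkpoint passed to `--ckpt_path` has to be located under the log directory described above. The sketch below is illustrative only: `experiment_dirs` and `latest_checkpoint` are hypothetical helpers, and the `*.ckpt` file pattern is an assumption about how checkpoints are named, not something this README states.

```python
from pathlib import Path


def experiment_dirs(exp_name, log_root="logs"):
    """Map output kinds to the directories named in this README."""
    base = Path(log_root) / exp_name / "MambaRoll"
    return {
        "checkpoints": base / "checkpoints",
        "val_samples": base / "val_samples",
        "test_samples": base / "test_samples",
    }


def latest_checkpoint(exp_name, log_root="logs", pattern="*.ckpt"):
    """Return the most recently modified checkpoint, or None if none exists.

    The default pattern is an assumption; adjust it to the actual file names.
    """
    ckpt_dir = experiment_dirs(exp_name, log_root)["checkpoints"]
    if not ckpt_dir.is_dir():
        return None
    candidates = sorted(ckpt_dir.glob(pattern), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None
```

For example, `latest_checkpoint("fastmri_t1_us8x")` returns a path that can be supplied as `$CKPT_PATH`, or `None` if the experiment has not saved anything yet.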
192+
193+
### Argument descriptions
194+
195+
| Argument | Description |
196+
|-----------------------------|--------------------------------------------------------------------------------------------------------------------------------|
197+
| `--config` | Config file path. Available config files: 'configs/config_fastmri.yaml' and 'configs/config_ct.yaml' |
198+
| `--trainer.logger.name` | Experiment name. |
199+
| `--model.mode` | Mode depending on data modality. Options: 'mri', 'ct'. |
200+
| `--data.dataset_dir` | Data set directory. |
201+
| `--data.contrast`           | Source contrast, e.g. 'T1', 'T2', ... for MRI. Should match the `.npz` file name for that contrast.                            |
202+
| `--data.us_factor`          | Undersampling factor, e.g. 4, 8.                                                                                               |
203+
| `--data.train_batch_size` | Train set batch size. |
204+
| `--data.val_batch_size` | Validation set batch size. |
205+
| `--trainer.max_epoch` | [Optional] Number of training epochs (default: 50). |
206+
| `--ckpt_path` | [Optional] Model checkpoint path to resume training. |
207+
| `--trainer.devices`         | [Optional] Device or list of devices. For multi-GPU, set to the list of device ids, e.g. `0,1,2,3` (default: `[0]`).           |
208+
209+
210+
## 🧪 Testing
211+
212+
Run the following command to start testing. The predicted images are saved under the `logs/$EXP_NAME/MambaRoll/test_samples` directory. By default, the script runs on a single GPU. To enable multi-GPU testing, set the `--trainer.devices` argument to the list of devices, e.g. `0,1,2,3`.
213+
214+
```
215+
python main.py test \
216+
--config $CONFIG_PATH \
217+
--model.mode $MODE \
218+
--data.dataset_dir $DATA_DIR \
219+
--data.contrast $CONTRAST \
220+
--data.us_factor $US_FACTOR \
221+
--data.test_batch_size $BS_TEST \
222+
--ckpt_path $CKPT_PATH
223+
```
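
To evaluate several contrasts or undersampling factors in one go, the command above can be assembled programmatically. This is a sketch, not part of the repository: `build_test_command` is a hypothetical helper that uses only the flags shown above, and making `--data.contrast` optional is an assumption based on the CT training example, which omits it. The checkpoint path in the example loop is a placeholder.

```python
import shlex


def build_test_command(config, mode, dataset_dir, us_factor, test_batch_size,
                       ckpt_path, contrast=None):
    """Assemble the `python main.py test` argument list shown above."""
    cmd = [
        "python", "main.py", "test",
        "--config", config,
        "--model.mode", mode,
        "--data.dataset_dir", dataset_dir,
    ]
    # Assumption: the contrast flag is only needed for MRI data.
    if contrast is not None:
        cmd += ["--data.contrast", contrast]
    cmd += [
        "--data.us_factor", str(us_factor),
        "--data.test_batch_size", str(test_batch_size),
        "--ckpt_path", ckpt_path,
    ]
    return cmd


if __name__ == "__main__":
    # Print one command per contrast; the checkpoint path is a placeholder.
    for contrast in ("T1", "T2", "FLAIR"):
        cmd = build_test_command(
            config="configs/config_fastmri.yaml",
            mode="mri",
            dataset_dir="../datasets/fastMRI",
            us_factor=8,
            test_batch_size=16,
            ckpt_path="path/to/checkpoint.ckpt",
            contrast=contrast,
        )
        print(shlex.join(cmd))
```

Returning a list rather than a single string keeps each argument separate, so the result can be passed directly to `subprocess.run` without shell quoting issues.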
224+
225+
### Argument descriptions
226+
227+
Some arguments are common to both training and testing and are not listed here. For details on those arguments, please refer to the training section.
228+
229+
| Argument | Description |
230+
|-----------------------------|--------------------------------------------|
231+
| `--data.test_batch_size` | Test set batch size. |
232+
| `--ckpt_path` | Model checkpoint path. |
233+
234+
235+
## ✒️ Citation
236+
You are encouraged to modify/distribute this code. However, please acknowledge this code and cite the paper appropriately.
237+
```
238+
@article{kabas2024mambaroll,
239+
title={Physics-Driven Autoregressive State Space Models for Medical Image Reconstruction},
240+
author={Bilal Kabas and Fuat Arslan and Valiyeh A. Nezhad and Saban Ozturk and Emine U. Saritas and Tolga Çukur},
241+
year={2024},
242+
journal={arXiv:2412.09331}
243+
}
244+
```
245+
246+
247+
### 💡 Acknowledgments
248+
249+
This repository uses code from the following projects:
250+
251+
- [mamba](https://github.com/state-spaces/mamba)
252+
- [deepinv](https://github.com/deepinv/deepinv)
253+
254+
<hr>
255+
Copyright © 2024, ICON Lab.
