<hr>
<h1 align="center">
  MambaRoll <br>
  <sub>Physics-Driven Autoregressive State Space Models for Medical Image Reconstruction</sub>
</h1>

<div align="center">
  <a href="https://bilalkabas.github.io/" target="_blank">Bilal Kabas</a><sup>1,2</sup> &nbsp; <b>·</b> &nbsp;
  <a href="https://github.com/fuat-arslan" target="_blank">Fuat Arslan</a><sup>1,2</sup> &nbsp; <b>·</b> &nbsp;
  <a href="https://github.com/Valiyeh" target="_blank">Valiyeh A. Nezhad</a><sup>1,2</sup> &nbsp; <b>·</b> &nbsp;
  <a href="https://scholar.google.com/citations?hl=en&user=_SujLxcAAAAJ" target="_blank">Saban Ozturk</a><sup>1,2</sup> &nbsp; <b>·</b> &nbsp;
  <a href="https://kilyos.ee.bilkent.edu.tr/~saritas/" target="_blank">Emine U. Saritas</a><sup>1,2</sup> &nbsp; <b>·</b> &nbsp;
  <a href="https://kilyos.ee.bilkent.edu.tr/~cukur/" target="_blank">Tolga Çukur</a><sup>1,2</sup>

  <span></span>

  <sup>1</sup>Bilkent University &nbsp; <sup>2</sup>UMRAM <br>
</div>
<hr>

<h3 align="center">[<a href="https://arxiv.org/abs/2412.09331">arXiv</a>]</h3>
Official PyTorch implementation of **MambaRoll**, a novel physics-driven autoregressive state space model for enhanced fidelity in medical image reconstruction. In each cascade of an unrolled architecture, MambaRoll employs an autoregressive framework based on physics-driven state space modules (PSSMs): each PSSM efficiently aggregates contextual features at a given spatial scale while maintaining fidelity to acquired data, and autoregressive prediction of next-scale feature maps from earlier spatial scales enhances the capture of multi-scale contextual features.
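For intuition on the "fidelity to acquired data" aspect mentioned above, a generic hard data-consistency step for MRI reconstruction can be sketched as follows. This is an illustrative sketch only, not MambaRoll's actual PSSM module; the function name is ours:

```python
import numpy as np

def data_consistency(x_rec, k_acquired, mask):
    """Replace reconstructed k-space samples with acquired ones
    wherever the undersampling mask is 1 (hard data consistency)."""
    k_rec = np.fft.fft2(x_rec, norm="ortho")
    k_dc = mask * k_acquired + (1 - mask) * k_rec
    return np.fft.ifft2(k_dc, norm="ortho")

# With a fully-sampled mask, the step is an identity (up to numerical error).
x = np.random.rand(64, 64)
k = np.fft.fft2(x, norm="ortho")
mask = np.ones((64, 64))
x_dc = data_consistency(x, k, mask)
```

With a partial mask, acquired k-space samples are kept verbatim while unacquired ones come from the network's current estimate.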

<p align="center">
  <img src="figures/architecture.png" alt="architecture">
</p>

## ⚙️ Installation

This repository has been developed and tested with `CUDA 12.2` and `Python 3.12`. The commands below create a conda environment with the required packages; make sure conda is installed first.

```
conda env create --file requirements.yaml
conda activate mambaroll
```

<details>
<summary>[Optional] Setting Up a Faster and Memory-Efficient Radon Transform</summary><br>

We use a faster (over 100x) and more memory-efficient (~4.5x) implementation of the Radon transform ([torch-radon](https://github.com/matteo-ronchetti/torch-radon)). To install it, run the commands below within the `mambaroll` conda environment.

```
git clone https://github.com/matteo-ronchetti/torch-radon.git
cd torch-radon
python setup.py install
```

</details>

## 🗂️ Prepare dataset

MambaRoll supports reconstruction for both MRI and CT. Accordingly, `datasets.py` provides two dataset classes: (1) `MRIDataset` and (2) `CTDataset`.

### 1. MRI dataset folder structure

The MRI dataset has a subfolder for each undersampling rate, e.g. `us4x`, `us8x`, etc. There is a separate `.npz` file for each contrast.

<details>
<summary>Details for npz files</summary><br>

A `<contrast>.npz` file has the following keys:

| Variable key    | Description                               | Shape                                     |
|-----------------|-------------------------------------------|-------------------------------------------|
| `image_fs`      | Coil-combined fully-sampled MR image.     | n_slices x 1 x height x width             |
| `image_us`      | Multi-coil undersampled MR image.         | n_slices x n_coils x height x width       |
| `us_masks`      | K-space undersampling masks.              | n_slices x 1 x height x width             |
| `coilmaps`      | Coil sensitivity maps.                    | n_slices x n_coils x height x width       |
| `subject_ids`   | Corresponding subject ID for each slice.  | n_slices                                  |
| `us_factor`     | Undersampling factor.                     | (Single integer value)                    |

</details>

```
fastMRI/
├── us4x/
│   ├── train/
│   │   ├── T1.npz
│   │   ├── T2.npz
│   │   └── FLAIR.npz
│   ├── test/
│   │   ├── T1.npz
│   │   ├── T2.npz
│   │   └── FLAIR.npz
│   └── val/
│       ├── T1.npz
│       ├── T2.npz
│       └── FLAIR.npz
├── us8x/
│   ├── train/...
│   ├── test/...
│   └── val/...
├── ...
```
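For reference, a toy `T1.npz` matching the documented keys and shapes can be generated and read back as below. All sizes and values are dummy placeholders, not real data:

```python
import numpy as np

n_slices, n_coils, height, width = 2, 8, 64, 64

# Write a dummy contrast file with the keys listed above.
np.savez(
    "T1.npz",
    image_fs=np.zeros((n_slices, 1, height, width), dtype=np.complex64),
    image_us=np.zeros((n_slices, n_coils, height, width), dtype=np.complex64),
    us_masks=np.ones((n_slices, 1, height, width), dtype=np.float32),
    coilmaps=np.zeros((n_slices, n_coils, height, width), dtype=np.complex64),
    subject_ids=np.array(["sub-01", "sub-01"]),
    us_factor=4,
)

# Read it back and check the shapes match the documented layout.
data = np.load("T1.npz")
print(data["image_us"].shape)  # (2, 8, 64, 64)
```

Such a file would then be placed under the split folder for its undersampling rate, e.g. `fastMRI/us4x/train/T1.npz`.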
### 2. CT dataset folder structure

Each split of the CT dataset contains images at multiple undersampling rates.

<details>
<summary>Details for npz files</summary><br>

`image_fs.npz` files contain the fully-sampled data with the following key:

| Variable key    | Description                               | Shape                                     |
|-----------------|-------------------------------------------|-------------------------------------------|
| `image_fs`      | Fully-sampled CT image.                   | n_slices x 1 x height x width             |

A `us<us_factor>x.npz` file has the following keys:

| Variable key          | Description                                                                                                      | Shape                                              |
|-----------------------|------------------------------------------------------------------------------------------------------------------|---------------------------------------------------|
| `image_us`            | Undersampled CT image.                                                                                           | n_slices x 1 x height x width                      |
| `sinogram_us`         | Corresponding sinograms for undersampled CTs.                                                                    | n_slices x 1 x detector_positions x n_projections  |
| `projection_angles`   | Projection angles at which the Radon transform was applied to fully-sampled images to obtain the undersampled ones. | n_slices x n_projections                           |
| `subject_ids`         | Corresponding subject ID for each slice.                                                                         | n_slices                                           |
| `us_factor`           | Undersampling factor.                                                                                            | (Single integer value)                             |

</details>

```
lodopab-ct/
├── train/
│   ├── image_fs.npz
│   ├── us4x.npz
│   └── us6x.npz
├── test/
│   ├── image_fs.npz
│   ├── us4x.npz
│   └── us6x.npz
└── val/
    ├── image_fs.npz
    ├── us4x.npz
    └── us6x.npz
```
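Similarly, a toy `us4x.npz` with the documented CT keys can be created and validated like this. The detector count, projection count, and angle convention here are illustrative assumptions, not values from the LoDoPaB-CT dataset:

```python
import numpy as np

n_slices, height, width = 2, 64, 64
det_positions, n_projections = 96, 45  # assumed sizes for illustration

np.savez(
    "us4x.npz",
    image_us=np.zeros((n_slices, 1, height, width), dtype=np.float32),
    sinogram_us=np.zeros((n_slices, 1, det_positions, n_projections), dtype=np.float32),
    # One row of projection angles per slice, evenly spaced over [0, pi).
    projection_angles=np.tile(
        np.linspace(0.0, np.pi, n_projections, endpoint=False), (n_slices, 1)
    ),
    subject_ids=np.array(["sub-01", "sub-01"]),
    us_factor=4,
)

data = np.load("us4x.npz")
print(data["sinogram_us"].shape)        # (2, 1, 96, 45)
print(data["projection_angles"].shape)  # (2, 45)
```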
## 🏃 Training

Run the following command to start/resume training. Model checkpoints are saved under the `logs/$EXP_NAME/MambaRoll/checkpoints` directory, and sample validation images under `logs/$EXP_NAME/MambaRoll/val_samples`. The script supports both single- and multi-GPU runs; by default, it runs on a single GPU. To enable multiple GPUs, set the `--trainer.devices` argument to a list of device ids, e.g. `0,1,2,3`. Note that multi-GPU training may lead to convergence issues, so multiple GPUs are recommended only for inference/testing.

```
python main.py fit \
  --config $CONFIG_PATH \
  --trainer.logger.name $EXP_NAME \
  --model.mode $MODE \
  --data.dataset_dir $DATA_DIR \
  --data.contrast $CONTRAST \
  --data.us_factor $US_FACTOR \
  --data.train_batch_size $BS_TRAIN \
  --data.val_batch_size $BS_VAL \
  [--trainer.max_epoch $N_EPOCHS] \
  [--ckpt_path $CKPT_PATH] \
  [--trainer.devices $DEVICES]
```

<details>
<summary>Example Commands</summary>

MRI reconstruction using the fastMRI dataset:

```
python main.py fit \
  --config configs/config_fastmri.yaml \
  --trainer.logger.name fastmri_t1_us8x \
  --data.dataset_dir ../datasets/fastMRI \
  --data.contrast T1 \
  --data.us_factor 8 \
  --data.train_batch_size 1 \
  --data.val_batch_size 16 \
  --trainer.devices [0]
```

CT reconstruction using the [LoDoPaB-CT](https://zenodo.org/records/3384092) dataset:

```
python main.py fit \
  --config configs/config_ct.yaml \
  --trainer.logger.name ct_us4x \
  --data.dataset_dir ../datasets/lodopab-ct/ \
  --data.us_factor 4 \
  --data.train_batch_size 1 \
  --data.val_batch_size 16 \
  --trainer.devices [0]
```
</details>

### Argument descriptions

| Argument                    | Description                                                                                                                      |
|-----------------------------|----------------------------------------------------------------------------------------------------------------------------------|
| `--config`                  | Config file path. Available config files: 'configs/config_fastmri.yaml' and 'configs/config_ct.yaml'                             |
| `--trainer.logger.name`     | Experiment name.                                                                                                                 |
| `--model.mode`              | Mode depending on data modality. Options: 'mri', 'ct'.                                                                           |
| `--data.dataset_dir`        | Dataset directory.                                                                                                               |
| `--data.contrast`           | Source contrast, e.g. 'T1', 'T2', ... for MRI. Should match the folder name for that contrast.                                   |
| `--data.us_factor`          | Undersampling factor, e.g. 4, 8.                                                                                                 |
| `--data.train_batch_size`   | Train set batch size.                                                                                                            |
| `--data.val_batch_size`     | Validation set batch size.                                                                                                       |
| `--trainer.max_epoch`       | [Optional] Number of training epochs (default: 50).                                                                              |
| `--ckpt_path`               | [Optional] Model checkpoint path to resume training.                                                                             |
| `--trainer.devices`         | [Optional] Device or list of devices. For multi-GPU, set to the list of device ids, e.g. `0,1,2,3` (default: `[0]`).             |


## 🧪 Testing

Run the following command to start testing. The predicted images are saved under the `logs/$EXP_NAME/MambaRoll/test_samples` directory. By default, the script runs on a single GPU. To enable multi-GPU testing, set the `--trainer.devices` argument to the list of devices, e.g. `0,1,2,3`.

```
python main.py test \
  --config $CONFIG_PATH \
  --model.mode $MODE \
  --data.dataset_dir $DATA_DIR \
  --data.contrast $CONTRAST \
  --data.us_factor $US_FACTOR \
  --data.test_batch_size $BS_TEST \
  --ckpt_path $CKPT_PATH
```

### Argument descriptions

Some arguments are common to both training and testing and are not listed here. For details on those arguments, please refer to the training section.

| Argument                    | Description                                |
|-----------------------------|--------------------------------------------|
| `--data.test_batch_size`    | Test set batch size.                       |
| `--ckpt_path`               | Model checkpoint path.                     |


## ✒️ Citation
You are encouraged to modify/distribute this code. However, please acknowledge this code and cite the paper appropriately.
```
@article{kabas2024mambaroll,
  title={Physics-Driven Autoregressive State Space Models for Medical Image Reconstruction},
  author={Bilal Kabas and Fuat Arslan and Valiyeh A. Nezhad and Saban Ozturk and Emine U. Saritas and Tolga Çukur},
  year={2024},
  journal={arXiv:2412.09331}
}
```


### 💡 Acknowledgments

This repository uses code from the following projects:

- [mamba](https://github.com/state-spaces/mamba)
- [deepinv](https://github.com/deepinv/deepinv)

<hr>
Copyright © 2024, ICON Lab.