Commit c49deda

FTW U-Net Models (#2719)
* add ftw unet models
* add ftw models to docs
* add unet to hubconf.py
* add unet to imports
* fix docs
* update licenses
* fix docs x2
* fix docs x3
* add tests
* use smp.create_model
* fix test cov
* add unet to api
* fix docs
* fix types
* fix docstring
* add typing
* change version added
* assert with unexpected keys
* missed an unexpected_keys
* Update conf.py
* Update conf.py
1 parent 3e9654f commit c49deda

File tree: 8 files changed, +244 -1 lines changed

docs/api/models.rst (+7)

@@ -89,6 +89,13 @@ Panopticon
 .. autofunction:: panopticon_vitb14
 .. autoclass:: Panopticon_Weights

+U-Net
+^^^^^
+
+.. autofunction:: unet
+.. autoclass:: Unet_Weights
+
+
 Vision Transformer
 ^^^^^^^^^^^^^^^^^^
docs/api/weights/sentinel2.csv (+4)

@@ -36,3 +36,7 @@
 Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
 Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
 Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
+Unet_Weights.SENTINEL2_2CLASS_FTW,8,`link <https://github.com/fieldsoftheworld/ftw-baselines>`__,`link <https://arxiv.org/abs/2409.16252>`__,"CC-BY-4.0",,,
+Unet_Weights.SENTINEL2_2CLASS_NC_FTW,8,`link <https://github.com/fieldsoftheworld/ftw-baselines>`__,`link <https://arxiv.org/abs/2409.16252>`__,"non-commercial",,,
+Unet_Weights.SENTINEL2_3CLASS_FTW,8,`link <https://github.com/fieldsoftheworld/ftw-baselines>`__,`link <https://arxiv.org/abs/2409.16252>`__,"CC-BY-4.0",,,
+Unet_Weights.SENTINEL2_3CLASS_NC_FTW,8,`link <https://github.com/fieldsoftheworld/ftw-baselines>`__,`link <https://arxiv.org/abs/2409.16252>`__,"non-commercial",,,

hubconf.py (+3 -1)

@@ -20,6 +20,7 @@
     scalemae_large_patch16,
     swin_v2_b,
     swin_v2_t,
+    unet,
     vit_base_patch14_dinov2,
     vit_base_patch16_224,
     vit_huge_patch14_224,
@@ -41,6 +42,7 @@
     'scalemae_large_patch16',
     'swin_v2_b',
     'swin_v2_t',
+    'unet',
     'vit_base_patch14_dinov2',
     'vit_base_patch16_224',
     'vit_huge_patch14_224',
@@ -49,4 +51,4 @@
     'vit_small_patch16_224',
 )

-dependencies = ['timm', 'torchvision']
+dependencies = ['timm', 'torchvision', 'segmentation_models_pytorch', 'kornia']
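With the unet entry point exported and the two new dependencies declared, the model becomes loadable through torch.hub without cloning the repository. A minimal sketch; the microsoft/torchgeo repository slug is an assumption here, not part of this diff, and encoder_weights=None is passed only to skip smp's default ImageNet encoder download:

    import torch

    # Resolve the 'unet' entry point registered in hubconf.py above; extra
    # kwargs are forwarded to torchgeo.models.unet (and on to smp.create_model).
    model = torch.hub.load('microsoft/torchgeo', 'unet', encoder_weights=None)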

tests/models/test_api.py (+4)

@@ -21,6 +21,7 @@
     ScaleMAELarge16_Weights,
     Swin_V2_B_Weights,
     Swin_V2_T_Weights,
+    Unet_Weights,
     ViTBase14_DINOv2_Weights,
     ViTBase16_Weights,
     ViTHuge14_Weights,
@@ -45,6 +46,7 @@
     scalemae_large_patch16,
     swin_v2_b,
     swin_v2_t,
+    unet,
     vit_base_patch14_dinov2,
     vit_base_patch16_224,
     vit_huge_patch14_224,
@@ -68,6 +70,7 @@
     scalemae_large_patch16,
     swin_v2_t,
     swin_v2_b,
+    unet,
     vit_base_patch14_dinov2,
     vit_base_patch16_224,
     vit_huge_patch14_224,
@@ -88,6 +91,7 @@
     ScaleMAELarge16_Weights,
     Swin_V2_T_Weights,
     Swin_V2_B_Weights,
+    Unet_Weights,
     ViTBase14_DINOv2_Weights,
     ViTBase16_Weights,
     ViTHuge14_Weights,

tests/models/test_unet.py (new file, +61)

@@ -0,0 +1,61 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+from pathlib import Path
+
+import pytest
+import segmentation_models_pytorch as smp
+import torch
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+from torchvision.models._api import WeightsEnum
+
+from torchgeo.models import Unet_Weights, unet
+
+
+class TestUnet:
+    @pytest.fixture(params=[*Unet_Weights])
+    def weights(self, request: SubRequest) -> WeightsEnum:
+        return request.param
+
+    @pytest.fixture
+    def mocked_weights(
+        self, tmp_path: Path, monkeypatch: MonkeyPatch, load_state_dict_from_url: None
+    ) -> WeightsEnum:
+        weights = Unet_Weights.SENTINEL2_2CLASS_FTW
+        path = tmp_path / f'{weights}.pth'
+        model = smp.Unet(
+            in_channels=weights.meta['in_chans'],
+            encoder_name=weights.meta['encoder'],
+            encoder_weights=None,
+            classes=weights.meta['num_classes'],
+        )
+        torch.save(model.state_dict(), path)
+        monkeypatch.setattr(weights.value, 'url', str(path))
+        return weights
+
+    def test_unet(self) -> None:
+        unet()
+
+    def test_unet_weights(self, mocked_weights: WeightsEnum) -> None:
+        unet(weights=mocked_weights)
+
+    def test_unet_weights_different_num_classes(
+        self, mocked_weights: WeightsEnum
+    ) -> None:
+        unet(weights=mocked_weights, classes=20)
+
+    def test_bands(self, weights: WeightsEnum) -> None:
+        if 'bands' in weights.meta:
+            assert len(weights.meta['bands']) == weights.meta['in_chans']
+
+    def test_transforms(self, weights: WeightsEnum) -> None:
+        c = weights.meta['in_chans']
+        sample = {
+            'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
+        }
+        weights.transforms(sample)
+
+    @pytest.mark.slow
+    def test_unet_download(self, weights: WeightsEnum) -> None:
+        unet(weights=weights)
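The mocked_weights fixture writes a randomly initialized smp.Unet state dict to a temporary file and monkeypatches the enum's url to point at it, so everything except the slow-marked download test runs offline. Assuming the repository's usual pytest marker setup, the file can be exercised locally with:

    pytest tests/models/test_unet.py -m 'not slow'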

torchgeo/models/__init__.py (+3)

@@ -31,6 +31,7 @@
 )
 from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16
 from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
+from .unet import Unet_Weights, unet
 from .vit import (
     ViTBase14_DINOv2_Weights,
     ViTBase16_Weights,
@@ -72,6 +73,7 @@
     'ScaleMAELarge16_Weights',
     'Swin_V2_B_Weights',
     'Swin_V2_T_Weights',
+    'Unet_Weights',
     'ViTBase14_DINOv2_Weights',
     'ViTBase16_Weights',
     'ViTHuge14_Weights',
@@ -96,6 +98,7 @@
     'scalemae_large_patch16',
     'swin_v2_b',
     'swin_v2_t',
+    'unet',
     'vit_base_patch14_dinov2',
     'vit_base_patch16_224',
     'vit_huge_patch14_224',

torchgeo/models/api.py (+4)

@@ -37,6 +37,7 @@
 )
 from .scale_mae import ScaleMAELarge16_Weights, scalemae_large_patch16
 from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
+from .unet import Unet_Weights, unet
 from .vit import (
     ViTBase14_DINOv2_Weights,
     ViTBase16_Weights,
@@ -67,6 +68,7 @@
     'scalemae_large_patch16': scalemae_large_patch16,
     'swin_v2_t': swin_v2_t,
     'swin_v2_b': swin_v2_b,
+    'unet': unet,
     'vit_small_patch16_224': vit_small_patch16_224,
     'vit_base_patch14_dinov2': vit_base_patch14_dinov2,
     'vit_base_patch16_224': vit_base_patch16_224,
@@ -88,6 +90,7 @@
     scalemae_large_patch16: ScaleMAELarge16_Weights,
     swin_v2_t: Swin_V2_T_Weights,
     swin_v2_b: Swin_V2_B_Weights,
+    unet: Unet_Weights,
     vit_small_patch16_224: ViTSmall16_Weights,
     vit_base_patch14_dinov2: ViTBase14_DINOv2_Weights,
     vit_base_patch16_224: ViTBase16_Weights,
@@ -106,6 +109,7 @@
     'scalemae_large_patch16': ScaleMAELarge16_Weights,
     'swin_v2_t': Swin_V2_T_Weights,
     'swin_v2_b': Swin_V2_B_Weights,
+    'unet': Unet_Weights,
     'vit_small_patch16_224': ViTSmall16_Weights,
     'vit_base_patch14_dinov2': ViTBase14_DINOv2_Weights,
     'vit_base_patch16_224': ViTBase16_Weights,
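Registering unet in these tables wires it into the module's name-based lookup. A short sketch, assuming the get_model and get_model_weights helpers that torchgeo.models builds on top of these dictionaries:

    from torchgeo.models import get_model, get_model_weights

    # Look up the weight enum registered for the new architecture name
    weight_enum = get_model_weights('unet')
    print(list(weight_enum))

    # Instantiate by name; keyword arguments are forwarded to torchgeo.models.unet
    model = get_model('unet', weights=weight_enum.SENTINEL2_2CLASS_FTW)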

torchgeo/models/unet.py (new file, +158)

@@ -0,0 +1,158 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Pre-trained U-Net models."""
+
+from typing import Any
+
+import kornia.augmentation as K
+import segmentation_models_pytorch as smp
+import torch
+from segmentation_models_pytorch import Unet
+from torchvision.models._api import Weights, WeightsEnum
+
+# Specified in https://github.com/fieldsoftheworld/ftw-baselines
+# First 4 S2 bands are for image t1 and last 4 bands are for image t2
+_ftw_sentinel2_bands = ['B4', 'B3', 'B2', 'B8A', 'B4', 'B3', 'B2', 'B8A']
+
+# https://github.com/fieldsoftheworld/ftw-baselines/blob/main/src/ftw/datamodules.py
+# Normalization by 3k (for S2 uint16 input)
+_ftw_transforms = K.AugmentationSequential(
+    K.Normalize(mean=torch.tensor(0.0), std=torch.tensor(3000.0)), data_keys=None
+)
+
+# https://github.com/pytorch/vision/pull/6883
+# https://github.com/pytorch/vision/pull/7107
+# Can be removed once torchvision>=0.15 is required
+Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
+
+
+class Unet_Weights(WeightsEnum):  # type: ignore[misc]
+    """U-Net weights.
+
+    For `smp <https://github.com/qubvel-org/segmentation_models.pytorch>`_
+    *Unet* implementation.
+
+    .. versionadded:: 0.8
+    """
+
+    SENTINEL2_2CLASS_FTW = Weights(
+        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/commercial/2-class/sentinel2_unet_effb3-9c04b7c6.pth',
+        transforms=_ftw_transforms,
+        meta={
+            'dataset': 'FTW',
+            'in_chans': 8,
+            'num_classes': 2,
+            'model': 'U-Net',
+            'encoder': 'efficientnet-b3',
+            'publication': 'https://arxiv.org/abs/2409.16252',
+            'repo': 'https://github.com/fieldsoftheworld/ftw-baselines',
+            'bands': _ftw_sentinel2_bands,
+            'license': 'CC-BY-4.0',
+        },
+    )
+    SENTINEL2_3CLASS_FTW = Weights(
+        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/commercial/3-class/sentinel2_unet_effb3-5d591cbb.pth',
+        transforms=_ftw_transforms,
+        meta={
+            'dataset': 'FTW',
+            'in_chans': 8,
+            'num_classes': 3,
+            'model': 'U-Net',
+            'encoder': 'efficientnet-b3',
+            'publication': 'https://arxiv.org/abs/2409.16252',
+            'repo': 'https://github.com/fieldsoftheworld/ftw-baselines',
+            'bands': _ftw_sentinel2_bands,
+            'license': 'CC-BY-4.0',
+        },
+    )
+    SENTINEL2_2CLASS_NC_FTW = Weights(
+        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/2-class/sentinel2_unet_effb3-bf010a31.pth',
+        transforms=_ftw_transforms,
+        meta={
+            'dataset': 'FTW',
+            'in_chans': 8,
+            'num_classes': 2,
+            'model': 'U-Net',
+            'encoder': 'efficientnet-b3',
+            'publication': 'https://arxiv.org/abs/2409.16252',
+            'repo': 'https://github.com/fieldsoftheworld/ftw-baselines',
+            'bands': _ftw_sentinel2_bands,
+            'license': 'non-commercial',
+        },
+    )
+    SENTINEL2_3CLASS_NC_FTW = Weights(
+        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/3-class/sentinel2_unet_effb3-ed36f465.pth',
+        transforms=_ftw_transforms,
+        meta={
+            'dataset': 'FTW',
+            'in_chans': 8,
+            'num_classes': 3,
+            'model': 'U-Net',
+            'encoder': 'efficientnet-b3',
+            'publication': 'https://arxiv.org/abs/2409.16252',
+            'repo': 'https://github.com/fieldsoftheworld/ftw-baselines',
+            'bands': _ftw_sentinel2_bands,
+            'license': 'non-commercial',
+        },
+    )
+
+
+def unet(
+    weights: Unet_Weights | None = None,
+    classes: int | None = None,
+    *args: Any,
+    **kwargs: Any,
+) -> Unet:
+    """U-Net model.
+
+    If you use this model in your research, please cite the following paper:
+
+    * https://arxiv.org/abs/1505.04597
+
+    .. versionadded:: 0.8
+
+    Args:
+        weights: Pre-trained model weights to use.
+        classes: Number of output classes. If not specified, the number of
+            classes will be inferred from the weights.
+        *args: Additional arguments to pass to ``segmentation_models_pytorch.create_model``
+        **kwargs: Additional keyword arguments to pass to ``segmentation_models_pytorch.create_model``
+
+    Returns:
+        A U-Net model.
+    """
+    kwargs['arch'] = 'Unet'
+
+    if weights:
+        kwargs['encoder_weights'] = None
+        kwargs['in_channels'] = weights.meta['in_chans']
+        kwargs['encoder_name'] = weights.meta['encoder']
+        kwargs['classes'] = weights.meta['num_classes'] if classes is None else classes
+    else:
+        kwargs['classes'] = 1 if classes is None else classes
+
+    model: Unet = smp.create_model(*args, **kwargs)
+
+    if weights:
+        state_dict = weights.get_state_dict(progress=True)
+
+        # Load full pretrained model
+        if kwargs['classes'] == weights.meta['num_classes']:
+            missing_keys, unexpected_keys = model.load_state_dict(
+                state_dict, strict=True
+            )
+        # Randomly initialize the segmentation head for the new task
+        else:
+            del state_dict['segmentation_head.0.weight']
+            del state_dict['segmentation_head.0.bias']
+            missing_keys, unexpected_keys = model.load_state_dict(
+                state_dict, strict=False
+            )
+        assert set(missing_keys) <= {
+            'segmentation_head.0.weight',
+            'segmentation_head.0.bias',
+        }
+        assert not unexpected_keys
+
+    return model
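Putting the new module to work, a minimal inference sketch: the weights, bands, and transforms all come from the file above, while the input tensor is random placeholder data standing in for a bi-temporal Sentinel-2 window.

    import torch

    from torchgeo.models import Unet_Weights, unet

    weights = Unet_Weights.SENTINEL2_2CLASS_FTW
    model = unet(weights=weights).eval()

    # 8 channels: B4, B3, B2, B8A for t1 followed by the same bands for t2,
    # at the uint16 reflectance scale the FTW normalization expects
    x = torch.randint(0, 3000, (1, 8, 256, 256)).float()
    sample = weights.transforms({'image': x})  # divides by 3000

    with torch.no_grad():
        logits = model(sample['image'])  # (1, 2, 256, 256) for the 2-class weights
    mask = logits.argmax(dim=1)  # per-pixel class indices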
