|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +"""Pre-trained U-Net models.""" |
| 5 | + |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +import kornia.augmentation as K |
| 9 | +import segmentation_models_pytorch as smp |
| 10 | +import torch |
| 11 | +from segmentation_models_pytorch import Unet |
| 12 | +from torchvision.models._api import Weights, WeightsEnum |
| 13 | + |
| 14 | +# Specified in https://github.com/fieldsoftheworld/ftw-baselines |
| 15 | +# First 4 S2 bands are for image t1 and last 4 bands are for image t2 |
| 16 | +_ftw_sentinel2_bands = ['B4', 'B3', 'B2', 'B8A', 'B4', 'B3', 'B2', 'B8A'] |
| 17 | + |
| 18 | +# https://github.com/fieldsoftheworld/ftw-baselines/blob/main/src/ftw/datamodules.py |
| 19 | +# Normalization by 3k (for S2 uint16 input) |
| 20 | +_ftw_transforms = K.AugmentationSequential( |
| 21 | + K.Normalize(mean=torch.tensor(0.0), std=torch.tensor(3000.0)), data_keys=None |
| 22 | +) |
| 23 | + |
| 24 | +# https://github.com/pytorch/vision/pull/6883 |
| 25 | +# https://github.com/pytorch/vision/pull/7107 |
| 26 | +# Can be removed once torchvision>=0.15 is required |
| 27 | +Weights.__deepcopy__ = lambda *args, **kwargs: args[0] |
| 28 | + |
| 29 | + |
| 30 | +class Unet_Weights(WeightsEnum): # type: ignore[misc] |
| 31 | + """U-Net weights. |
| 32 | +
|
| 33 | + For `smp <https://github.com/qubvel-org/segmentation_models.pytorch>`_ |
| 34 | + *Unet* implementation. |
| 35 | +
|
| 36 | + .. versionadded:: 0.8 |
| 37 | + """ |
| 38 | + |
| 39 | + SENTINEL2_2CLASS_FTW = Weights( |
| 40 | + url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/commercial/2-class/sentinel2_unet_effb3-9c04b7c6.pth', |
| 41 | + transforms=_ftw_transforms, |
| 42 | + meta={ |
| 43 | + 'dataset': 'FTW', |
| 44 | + 'in_chans': 8, |
| 45 | + 'num_classes': 2, |
| 46 | + 'model': 'U-Net', |
| 47 | + 'encoder': 'efficientnet-b3', |
| 48 | + 'publication': 'https://arxiv.org/abs/2409.16252', |
| 49 | + 'repo': 'https://github.com/fieldsoftheworld/ftw-baselines', |
| 50 | + 'bands': _ftw_sentinel2_bands, |
| 51 | + 'license': 'CC-BY-4.0', |
| 52 | + }, |
| 53 | + ) |
| 54 | + SENTINEL2_3CLASS_FTW = Weights( |
| 55 | + url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/commercial/3-class/sentinel2_unet_effb3-5d591cbb.pth', |
| 56 | + transforms=_ftw_transforms, |
| 57 | + meta={ |
| 58 | + 'dataset': 'FTW', |
| 59 | + 'in_chans': 8, |
| 60 | + 'num_classes': 3, |
| 61 | + 'model': 'U-Net', |
| 62 | + 'encoder': 'efficientnet-b3', |
| 63 | + 'publication': 'https://arxiv.org/abs/2409.16252', |
| 64 | + 'repo': 'https://github.com/fieldsoftheworld/ftw-baselines', |
| 65 | + 'bands': _ftw_sentinel2_bands, |
| 66 | + 'license': 'CC-BY-4.0', |
| 67 | + }, |
| 68 | + ) |
| 69 | + SENTINEL2_2CLASS_NC_FTW = Weights( |
| 70 | + url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/2-class/sentinel2_unet_effb3-bf010a31.pth', |
| 71 | + transforms=_ftw_transforms, |
| 72 | + meta={ |
| 73 | + 'dataset': 'FTW', |
| 74 | + 'in_chans': 8, |
| 75 | + 'num_classes': 2, |
| 76 | + 'model': 'U-Net', |
| 77 | + 'encoder': 'efficientnet-b3', |
| 78 | + 'publication': 'https://arxiv.org/abs/2409.16252', |
| 79 | + 'repo': 'https://github.com/fieldsoftheworld/ftw-baselines', |
| 80 | + 'bands': _ftw_sentinel2_bands, |
| 81 | + 'license': 'non-commercial', |
| 82 | + }, |
| 83 | + ) |
| 84 | + SENTINEL2_3CLASS_NC_FTW = Weights( |
| 85 | + url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/3-class/sentinel2_unet_effb3-ed36f465.pth', |
| 86 | + transforms=_ftw_transforms, |
| 87 | + meta={ |
| 88 | + 'dataset': 'FTW', |
| 89 | + 'in_chans': 8, |
| 90 | + 'num_classes': 3, |
| 91 | + 'model': 'U-Net', |
| 92 | + 'encoder': 'efficientnet-b3', |
| 93 | + 'publication': 'https://arxiv.org/abs/2409.16252', |
| 94 | + 'repo': 'https://github.com/fieldsoftheworld/ftw-baselines', |
| 95 | + 'bands': _ftw_sentinel2_bands, |
| 96 | + 'license': 'non-commercial', |
| 97 | + }, |
| 98 | + ) |
| 99 | + |
| 100 | + |
| 101 | +def unet( |
| 102 | + weights: Unet_Weights | None = None, |
| 103 | + classes: int | None = None, |
| 104 | + *args: Any, |
| 105 | + **kwargs: Any, |
| 106 | +) -> Unet: |
| 107 | + """U-Net model. |
| 108 | +
|
| 109 | + If you use this model in your research, please cite the following paper: |
| 110 | +
|
| 111 | + * https://arxiv.org/abs/1505.04597 |
| 112 | +
|
| 113 | + .. versionadded:: 0.8 |
| 114 | +
|
| 115 | + Args: |
| 116 | + weights: Pre-trained model weights to use. |
| 117 | + classes: Number of output classes. If not specified, the number of |
| 118 | + classes will be inferred from the weights. |
| 119 | + *args: Additional arguments to pass to ``segmentation_models_pytorch.create_model`` |
| 120 | + **kwargs: Additional keyword arguments to pass to ``segmentation_models_pytorch.create_model`` |
| 121 | +
|
| 122 | + Returns: |
| 123 | + A U-Net model. |
| 124 | + """ |
| 125 | + kwargs['arch'] = 'Unet' |
| 126 | + |
| 127 | + if weights: |
| 128 | + kwargs['encoder_weights'] = None |
| 129 | + kwargs['in_channels'] = weights.meta['in_chans'] |
| 130 | + kwargs['encoder_name'] = weights.meta['encoder'] |
| 131 | + kwargs['classes'] = weights.meta['num_classes'] if classes is None else classes |
| 132 | + else: |
| 133 | + kwargs['classes'] = 1 if classes is None else classes |
| 134 | + |
| 135 | + model: Unet = smp.create_model(*args, **kwargs) |
| 136 | + |
| 137 | + if weights: |
| 138 | + state_dict = weights.get_state_dict(progress=True) |
| 139 | + |
| 140 | + # Load full pretrained model |
| 141 | + if kwargs['classes'] == weights.meta['num_classes']: |
| 142 | + missing_keys, unexpected_keys = model.load_state_dict( |
| 143 | + state_dict, strict=True |
| 144 | + ) |
| 145 | + # Random initialize segmentation head for new task |
| 146 | + else: |
| 147 | + del state_dict['segmentation_head.0.weight'] |
| 148 | + del state_dict['segmentation_head.0.bias'] |
| 149 | + missing_keys, unexpected_keys = model.load_state_dict( |
| 150 | + state_dict, strict=False |
| 151 | + ) |
| 152 | + assert set(missing_keys) <= { |
| 153 | + 'segmentation_head.0.weight', |
| 154 | + 'segmentation_head.0.bias', |
| 155 | + } |
| 156 | + assert not unexpected_keys |
| 157 | + |
| 158 | + return model |
0 commit comments