Skip to content

Commit 4455851

Browse files
authored
Merge pull request #133 from allenai/favyen/20250514-pastis
Add PASTIS dataset, + evaluation for Helios
2 parents 9c6cbba + 039ee6b commit 4455851

File tree

5 files changed

+895
-0
lines changed

5 files changed

+895
-0
lines changed
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
model:
2+
class_path: rslearn.train.lightning_module.RslearnLightningModule
3+
init_args:
4+
model:
5+
class_path: rslearn.models.multitask.MultiTaskModel
6+
init_args:
7+
encoder:
8+
- class_path: rslearn.models.simple_time_series.SimpleTimeSeries
9+
init_args:
10+
encoder:
11+
class_path: rslearn.models.swin.Swin
12+
init_args:
13+
pretrained: true
14+
input_channels: 9
15+
output_layers: [1, 3, 5, 7]
16+
image_channels: 9
17+
groups: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]
18+
decoders:
19+
segment:
20+
- class_path: rslearn.models.unet.UNetDecoder
21+
init_args:
22+
in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]]
23+
out_channels: 20
24+
conv_layers_per_resolution: 2
25+
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
26+
lr: 0.0001
27+
plateau: true
28+
plateau_factor: 0.2
29+
plateau_patience: 2
30+
plateau_min_lr: 0
31+
plateau_cooldown: 10
32+
restore_config:
33+
restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth
34+
remap_prefixes:
35+
- ["backbone.backbone.backbone.", "encoder.0.encoder.model."]
36+
data:
37+
class_path: rslearn.train.data_module.RslearnDataModule
38+
init_args:
39+
path: /weka/dfive-default/rslearn-eai/datasets/pastis/rslearn_dataset/
40+
inputs:
41+
sentinel2_0:
42+
data_type: "raster"
43+
layers: ["sentinel2"]
44+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
45+
passthrough: true
46+
sentinel2_1:
47+
data_type: "raster"
48+
layers: ["sentinel2.1"]
49+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
50+
passthrough: true
51+
sentinel2_2:
52+
data_type: "raster"
53+
layers: ["sentinel2.2"]
54+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
55+
passthrough: true
56+
sentinel2_3:
57+
data_type: "raster"
58+
layers: ["sentinel2.3"]
59+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
60+
passthrough: true
61+
sentinel2_4:
62+
data_type: "raster"
63+
layers: ["sentinel2.4"]
64+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
65+
passthrough: true
66+
sentinel2_5:
67+
data_type: "raster"
68+
layers: ["sentinel2.5"]
69+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
70+
passthrough: true
71+
sentinel2_6:
72+
data_type: "raster"
73+
layers: ["sentinel2.6"]
74+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
75+
passthrough: true
76+
sentinel2_7:
77+
data_type: "raster"
78+
layers: ["sentinel2.7"]
79+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
80+
passthrough: true
81+
sentinel2_8:
82+
data_type: "raster"
83+
layers: ["sentinel2.8"]
84+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
85+
passthrough: true
86+
sentinel2_9:
87+
data_type: "raster"
88+
layers: ["sentinel2.9"]
89+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
90+
passthrough: true
91+
sentinel2_10:
92+
data_type: "raster"
93+
layers: ["sentinel2.10"]
94+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
95+
passthrough: true
96+
sentinel2_11:
97+
data_type: "raster"
98+
layers: ["sentinel2.11"]
99+
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
100+
passthrough: true
101+
targets:
102+
data_type: "raster"
103+
layers: ["label"]
104+
bands: ["class"]
105+
is_target: true
106+
task:
107+
class_path: rslearn.train.tasks.multi_task.MultiTask
108+
init_args:
109+
tasks:
110+
segment:
111+
class_path: rslearn.train.tasks.segmentation.SegmentationTask
112+
init_args:
113+
num_classes: 20
114+
remap_values: [[0, 1], [0, 255]]
115+
zero_is_invalid: true
116+
metric_kwargs:
117+
average: "micro"
118+
enable_miou_metric: true
119+
input_mapping:
120+
segment:
121+
targets: "targets"
122+
batch_size: 8
123+
num_workers: 32
124+
default_config:
125+
transforms:
126+
- class_path: rslearn.train.transforms.normalize.Normalize
127+
init_args:
128+
mean: 0
129+
std: 3000
130+
valid_range: [0, 1]
131+
bands: [0, 1, 2]
132+
selectors:
133+
- sentinel2_0
134+
- sentinel2_1
135+
- sentinel2_2
136+
- sentinel2_3
137+
- sentinel2_4
138+
- sentinel2_5
139+
- sentinel2_6
140+
- sentinel2_7
141+
- sentinel2_8
142+
- sentinel2_9
143+
- sentinel2_10
144+
- sentinel2_11
145+
- class_path: rslearn.train.transforms.normalize.Normalize
146+
init_args:
147+
mean: 0
148+
std: 8160
149+
valid_range: [0, 1]
150+
bands: [3, 4, 5, 6, 7, 8]
151+
selectors:
152+
- sentinel2_0
153+
- sentinel2_1
154+
- sentinel2_2
155+
- sentinel2_3
156+
- sentinel2_4
157+
- sentinel2_5
158+
- sentinel2_6
159+
- sentinel2_7
160+
- sentinel2_8
161+
- sentinel2_9
162+
- sentinel2_10
163+
- sentinel2_11
164+
- class_path: rslearn.train.transforms.concatenate.Concatenate
165+
init_args:
166+
selections:
167+
sentinel2_0: []
168+
sentinel2_1: []
169+
sentinel2_2: []
170+
sentinel2_3: []
171+
sentinel2_4: []
172+
sentinel2_5: []
173+
sentinel2_6: []
174+
sentinel2_7: []
175+
sentinel2_8: []
176+
sentinel2_9: []
177+
sentinel2_10: []
178+
sentinel2_11: []
179+
output_selector: image
180+
train_config:
181+
patch_size: 64
182+
transforms:
183+
- class_path: rslearn.train.transforms.normalize.Normalize
184+
init_args:
185+
mean: 0
186+
std: 3000
187+
valid_range: [0, 1]
188+
bands: [0, 1, 2]
189+
selectors:
190+
- sentinel2_0
191+
- sentinel2_1
192+
- sentinel2_2
193+
- sentinel2_3
194+
- sentinel2_4
195+
- sentinel2_5
196+
- sentinel2_6
197+
- sentinel2_7
198+
- sentinel2_8
199+
- sentinel2_9
200+
- sentinel2_10
201+
- sentinel2_11
202+
- class_path: rslearn.train.transforms.normalize.Normalize
203+
init_args:
204+
mean: 0
205+
std: 8160
206+
valid_range: [0, 1]
207+
bands: [3, 4, 5, 6, 7, 8]
208+
selectors:
209+
- sentinel2_0
210+
- sentinel2_1
211+
- sentinel2_2
212+
- sentinel2_3
213+
- sentinel2_4
214+
- sentinel2_5
215+
- sentinel2_6
216+
- sentinel2_7
217+
- sentinel2_8
218+
- sentinel2_9
219+
- sentinel2_10
220+
- sentinel2_11
221+
- class_path: rslearn.train.transforms.concatenate.Concatenate
222+
init_args:
223+
selections:
224+
sentinel2_0: []
225+
sentinel2_1: []
226+
sentinel2_2: []
227+
sentinel2_3: []
228+
sentinel2_4: []
229+
sentinel2_5: []
230+
sentinel2_6: []
231+
sentinel2_7: []
232+
sentinel2_8: []
233+
sentinel2_9: []
234+
sentinel2_10: []
235+
sentinel2_11: []
236+
output_selector: image
237+
- class_path: rslearn.train.transforms.flip.Flip
238+
init_args:
239+
image_selectors: ["image", "target/segment/classes", "target/segment/valid"]
240+
groups: ["fold1", "fold2", "fold3"]
241+
val_config:
242+
patch_size: 128
243+
groups: ["fold4"]
244+
test_config:
245+
patch_size: 128
246+
groups: ["fold5"]
247+
trainer:
248+
max_epochs: 500
249+
callbacks:
250+
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
251+
init_args:
252+
logging_interval: "epoch"
253+
- class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
254+
init_args:
255+
module_selector: ["model", "encoder", 0, "encoder", "model"]
256+
unfreeze_at_epoch: 5
257+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
258+
init_args:
259+
save_top_k: 1
260+
save_last: true
261+
monitor: val_segment/accuracy
262+
mode: max
263+
rslp_project: helios_finetuning
264+
rslp_experiment: placeholder

0 commit comments

Comments
 (0)