Skip to content

Commit 4e8122d

Browse files
committed
Add notebook using DVCLive
1 parent c62c110 commit 4e8122d

File tree

1 file changed

+303
-0
lines changed

1 file changed

+303
-0
lines changed

notebooks/TrainSegModel.ipynb

+303
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import os\n",
10+
"import shutil\n",
11+
"from functools import partial\n",
12+
"from pathlib import Path\n",
13+
"import warnings\n",
14+
"\n",
15+
"import numpy as np\n",
16+
"import torch\n",
17+
"from box import ConfigBox\n",
18+
"from dvclive import Live\n",
19+
"from dvclive.fastai import DVCLiveCallback\n",
20+
"from fastai.data.all import Normalize, get_files\n",
21+
"from fastai.metrics import DiceMulti\n",
22+
"from fastai.vision.all import (Resize, SegmentationDataLoaders,\n",
23+
" imagenet_stats, models, unet_learner)\n",
24+
"from ruamel.yaml import YAML\n",
25+
"from PIL import Image\n",
26+
"\n",
27+
"os.chdir(\"..\")\n",
28+
"warnings.filterwarnings(\"ignore\")"
29+
]
30+
},
31+
{
32+
"attachments": {},
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"### Load data and split it into train/test\n",
37+
"\n",
38+
"We have some [data in DVC](https://dvc.org/doc/start/data-management/data-versioning) that we can pull. \n",
39+
"\n",
40+
"This data includes:\n",
41+
"* satellite images\n",
42+
"* masks of the swimming pools in each satellite image\n",
43+
"\n",
44+
"DVC can help connect your data to your repo, but it isn't necessary to have your data in DVC to start tracking experiments with DVC and DVCLive."
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"metadata": {},
51+
"outputs": [],
52+
"source": [
53+
"!dvc pull"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"test_regions = [\"REGION_1-\"]\n",
63+
"\n",
64+
"img_fpaths = get_files(Path(\"data\") / \"pool_data\" / \"images\", extensions=\".jpg\")\n",
65+
"\n",
66+
"train_data_dir = Path(\"data\") / \"train_data\"\n",
67+
"train_data_dir.mkdir(exist_ok=True)\n",
68+
"test_data_dir = Path(\"data\") / \"test_data\"\n",
69+
"test_data_dir.mkdir(exist_ok=True)\n",
70+
"for img_path in img_fpaths:\n",
71+
" msk_path = Path(\"data\") / \"pool_data\" / \"masks\" / f\"{img_path.stem}.png\"\n",
72+
" if any(region in str(img_path) for region in test_regions):\n",
73+
" shutil.copy(img_path, test_data_dir)\n",
74+
" shutil.copy(msk_path, test_data_dir)\n",
75+
" else:\n",
76+
" shutil.copy(img_path, train_data_dir)\n",
77+
" shutil.copy(msk_path, train_data_dir)"
78+
]
79+
},
80+
{
81+
"attachments": {},
82+
"cell_type": "markdown",
83+
"metadata": {},
84+
"source": [
85+
"### Create a data loader\n",
86+
"\n",
87+
"Load and prepare the images and masks by creating a data loader."
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"def get_mask_path(x, train_data_dir):\n",
97+
" return Path(train_data_dir) / f\"{Path(x).stem}.png\""
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"bs = 8\n",
107+
"valid_pct = 0.20\n",
108+
"img_size = 256\n",
109+
"\n",
110+
"data_loader = SegmentationDataLoaders.from_label_func(\n",
111+
" path=train_data_dir,\n",
112+
" fnames=get_files(train_data_dir, extensions=\".jpg\"),\n",
113+
" label_func=partial(get_mask_path, train_data_dir=train_data_dir),\n",
114+
" codes=[\"not-pool\", \"pool\"],\n",
115+
" bs=bs,\n",
116+
" valid_pct=valid_pct,\n",
117+
" item_tfms=Resize(img_size),\n",
118+
" batch_tfms=[\n",
119+
" Normalize.from_stats(*imagenet_stats),\n",
120+
" ],\n",
121+
" )"
122+
]
123+
},
124+
{
125+
"attachments": {},
126+
"cell_type": "markdown",
127+
"metadata": {},
128+
"source": [
129+
"### Review a sample batch of data\n",
130+
"\n",
131+
"Below are some examples of the images overlaid with their masks."
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": null,
137+
"metadata": {},
138+
"outputs": [],
139+
"source": [
140+
"data_loader.show_batch(alpha=0.7)"
141+
]
142+
},
143+
{
144+
"attachments": {},
145+
"cell_type": "markdown",
146+
"metadata": {},
147+
"source": [
148+
"### Train multiple models with different learning rates using `DVCLiveCallback`\n",
149+
"\n",
150+
"Set up model training, using DVCLive to capture the results of each experiment."
151+
]
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": null,
156+
"metadata": {},
157+
"outputs": [],
158+
"source": [
159+
"def dice(mask_pred, mask_true, classes=[0, 1], eps=1e-6):\n",
160+
" dice_list = []\n",
161+
" for c in classes:\n",
162+
" y_true = mask_true == c\n",
163+
" y_pred = mask_pred == c\n",
164+
" intersection = 2.0 * np.sum(y_true * y_pred)\n",
165+
" dice = intersection / (np.sum(y_true) + np.sum(y_pred) + eps)\n",
166+
" dice_list.append(dice)\n",
167+
" return np.mean(dice_list)\n",
168+
"\n",
169+
"\n",
170+
"def evaluate(learn):\n",
171+
" test_img_fpaths = sorted(get_files(Path(\"data\") / \"test_data\", extensions=\".jpg\"))\n",
172+
" test_dl = learn.dls.test_dl(test_img_fpaths)\n",
173+
" preds, _ = learn.get_preds(dl=test_dl)\n",
174+
" masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n",
175+
" test_mask_fpaths = [\n",
176+
" get_mask_path(fpath, Path(\"data\") / \"test_data\") for fpath in test_img_fpaths\n",
177+
" ]\n",
178+
" masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n",
179+
"\n",
180+
" dice_multi = 0.0\n",
181+
" for ii in range(len(masks_true)):\n",
182+
" mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n",
183+
" mask_pred = np.array(\n",
184+
" Image.fromarray(mask_pred).resize((mask_true.shape[1], mask_true.shape[0])),\n",
185+
" dtype=int\n",
186+
" )\n",
187+
" mask_true = np.array(mask_true, dtype=int)\n",
188+
" dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n",
189+
"\n",
190+
" return dice_multi"
191+
]
192+
},
193+
{
194+
"cell_type": "code",
195+
"execution_count": null,
196+
"metadata": {},
197+
"outputs": [],
198+
"source": [
199+
"train_arch = 'shufflenet_v2_x2_0'\n",
200+
"\n",
201+
"for base_lr in [0.001, 0.005, 0.01]:\n",
202+
" # initialize dvclive, optionally provide output path, and show report in notebook\n",
203+
" # don't save dvc experiment until post-training metrics below\n",
204+
" with Live(\"results/train\", report=\"notebook\", save_dvc_exp=False) as live:\n",
205+
" # log a parameter\n",
206+
" live.log_param(\"train_arch\", train_arch)\n",
207+
" fine_tune_args = {\n",
208+
" 'epochs': 8,\n",
209+
" 'base_lr': base_lr\n",
210+
" }\n",
211+
" # log a dict of parameters\n",
212+
" live.log_params(fine_tune_args)\n",
213+
"\n",
214+
" learn = unet_learner(data_loader, \n",
215+
" arch=getattr(models, train_arch), \n",
216+
" metrics=DiceMulti)\n",
217+
" # train model and automatically capture metrics with DVCLiveCallback\n",
218+
" learn.fine_tune(\n",
219+
" **fine_tune_args,\n",
220+
" cbs=[DVCLiveCallback(live=live)])\n",
221+
"\n",
222+
" # save model artifact to dvc\n",
223+
" models_dir = Path(\"models\")\n",
224+
" models_dir.mkdir(exist_ok=True)\n",
225+
" learn.export(fname=(models_dir / \"model.pkl\").absolute())\n",
226+
" torch.save(learn.model, (models_dir / \"model.pth\").absolute())\n",
227+
" live.log_artifact(\n",
228+
" str(models_dir / \"model.pkl\"),\n",
229+
" type=\"model\",\n",
230+
" name=\"pool-segmentation\",\n",
231+
" desc=\"This is a Computer Vision (CV) model that's segmenting out swimming pools from satellite images.\",\n",
232+
" labels=[\"cv\", \"segmentation\", \"satellite-images\", \"unet\"],\n",
233+
" )\n",
234+
"\n",
235+
" # add additional post-training summary metrics.\n",
236+
" with Live(\"results/evaluate\") as live:\n",
237+
" live.summary[\"dice_multi\"] = evaluate(learn)"
238+
]
239+
},
240+
{
241+
"cell_type": "code",
242+
"execution_count": null,
243+
"metadata": {},
244+
"outputs": [],
245+
"source": [
246+
"# Compare experiments\n",
247+
"!dvc exp show --only-changed"
248+
]
249+
},
250+
{
251+
"attachments": {},
252+
"cell_type": "markdown",
253+
"metadata": {},
254+
"source": [
255+
"### Review sample preditions vs ground truth\n",
256+
"\n",
257+
"Below are some example of the predicted masks."
258+
]
259+
},
260+
{
261+
"cell_type": "code",
262+
"execution_count": null,
263+
"metadata": {},
264+
"outputs": [],
265+
"source": [
266+
"learn.show_results(max_n=6, alpha=0.7)"
267+
]
268+
},
269+
{
270+
"cell_type": "code",
271+
"execution_count": null,
272+
"metadata": {},
273+
"outputs": [],
274+
"source": []
275+
}
276+
],
277+
"metadata": {
278+
"kernelspec": {
279+
"display_name": "Python 3 (ipykernel)",
280+
"language": "python",
281+
"name": "python3"
282+
},
283+
"language_info": {
284+
"codemirror_mode": {
285+
"name": "ipython",
286+
"version": 3
287+
},
288+
"file_extension": ".py",
289+
"mimetype": "text/x-python",
290+
"name": "python",
291+
"nbconvert_exporter": "python",
292+
"pygments_lexer": "ipython3",
293+
"version": "3.11.6"
294+
},
295+
"vscode": {
296+
"interpreter": {
297+
"hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
298+
}
299+
}
300+
},
301+
"nbformat": 4,
302+
"nbformat_minor": 4
303+
}

0 commit comments

Comments
 (0)