Skip to content

Commit 7b16189

Browse files
committed
feat: Sanity Checking: Use a trained model for predictions
Use matplotlib to plot the probabilities and the input image.
1 parent e798869 commit 7b16189

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

Image Classifier Project.ipynb

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,34 @@
14141414
},
14151415
"cell_type": "code",
14161416
"source": [
1417-
"# TODO: Display an image along with the top 5 classes"
1417+
"# Display an image along with the top 5 classes\n",
1418+
"probs, classes = predict(image_path, model)\n",
1419+
"\n",
1420+
"probs = probs.data.numpy().squeeze()\n",
1421+
"classes = classes.data.numpy().squeeze()\n",
1422+
"classes = [cat_label_to_name[clazz].title() for clazz in classes]\n",
1423+
"\n",
1424+
"label = class_to_idx[str(category)]\n",
1425+
"title = f'{cat_label_to_name[label].title()}'\n",
1426+
"\n",
1427+
"fig = plt.figure(figsize=(4, 10))\n",
1428+
"\n",
1429+
"ax1 = fig.add_subplot(2, 1, 1, xticks=[], yticks=[])\n",
1430+
"\n",
1431+
"image = Image.open(image_path)\n",
1432+
"image = process_image(image)\n",
1433+
"imshow(image, ax1, title)\n",
1434+
"\n",
1435+
"ax2 = fig.add_subplot(2, 1, 2, xticks=[], yticks=[])\n",
1436+
"ax2.barh(np.arange(5), probs)\n",
1437+
"ax2.set_yticks(np.arange(5))\n",
1438+
"ax2.set_yticklabels(classes)\n",
1439+
"ax2.set_ylim(-1, 5)\n",
1440+
"ax2.invert_yaxis()\n",
1441+
"ax2.set_xlim(0, 1.1)\n",
1442+
"ax2.set_title('Class Probability')\n",
1443+
"\n",
1444+
"plt.show()"
14181445
],
14191446
"execution_count": 0,
14201447
"outputs": []

0 commit comments

Comments
 (0)