|
1414 | 1414 | },
|
1415 | 1415 | "cell_type": "code",
|
1416 | 1416 | "source": [
|
1417 |
| - "# TODO: Display an image along with the top 5 classes" |
| 1417 | + "# Display an image along with the top 5 classes\n", |
| 1418 | + "probs, classes = predict(image_path, model)\n", |
| 1419 | + "\n", |
| 1420 | + "probs = probs.data.numpy().squeeze()\n", |
| 1421 | + "classes = classes.data.numpy().squeeze()\n", |
| 1422 | + "classes = [cat_label_to_name[clazz].title() for clazz in classes]\n", |
| 1423 | + "\n", |
| 1424 | + "label = class_to_idx[str(category)]\n", |
| 1425 | + "title = f'{cat_label_to_name[label].title()}'\n", |
| 1426 | + "\n", |
| 1427 | + "fig = plt.figure(figsize=(4, 10))\n", |
| 1428 | + "\n", |
| 1429 | + "ax1 = fig.add_subplot(2, 1, 1, xticks=[], yticks=[])\n", |
| 1430 | + "\n", |
| 1431 | + "image = Image.open(image_path)\n", |
| 1432 | + "image = process_image(image)\n", |
| 1433 | + "imshow(image, ax1, title)\n", |
| 1434 | + "\n", |
| 1435 | + "ax2 = fig.add_subplot(2, 1, 2, xticks=[], yticks=[])\n", |
| 1436 | + "ax2.barh(np.arange(5), probs)\n", |
| 1437 | + "ax2.set_yticks(np.arange(5))\n", |
| 1438 | + "ax2.set_yticklabels(classes)\n", |
| 1439 | + "ax2.set_ylim(-1, 5)\n", |
| 1440 | + "ax2.invert_yaxis()\n", |
| 1441 | + "ax2.set_xlim(0, 1.1)\n", |
| 1442 | + "ax2.set_title('Class Probability')\n", |
| 1443 | + "\n", |
| 1444 | + "plt.show()" |
1418 | 1445 | ],
|
1419 | 1446 | "execution_count": 0,
|
1420 | 1447 | "outputs": []
|
|
0 commit comments