Skip to content

Commit 3e9654f

Browse files
authored
Plot Method Added to GBIF Dataset (#2741)
* added plot method * added test for plot method
1 parent 4258bbc commit 3e9654f

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

tests/datasets/test_gbif.py

+8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import os
55
from pathlib import Path
66

7+
import matplotlib.pyplot as plt
78
import pytest
9+
from matplotlib.figure import Figure
810

911
from torchgeo.datasets import (
1012
GBIF,
@@ -46,3 +48,9 @@ def test_invalid_query(self, dataset: GBIF) -> None:
4648
IndexError, match='query: .* not found in index with bounds:'
4749
):
4850
dataset[query]
51+
52+
def test_plot(self, dataset: GBIF) -> None:
53+
sample = dataset[dataset.bounds]
54+
fig = dataset.plot(sample, suptitle='test')
55+
assert isinstance(fig, Figure)
56+
plt.close()

torchgeo/datasets/gbif.py

+61
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from datetime import datetime, timedelta
1010
from typing import Any
1111

12+
import matplotlib.pyplot as plt
1213
import numpy as np
1314
import pandas as pd
15+
from matplotlib.figure import Figure
16+
from matplotlib.ticker import FuncFormatter
1417
from rasterio.crs import CRS
1518

1619
from .errors import DatasetNotFoundError
@@ -140,3 +143,61 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
140143
sample = {'crs': self.crs, 'bounds': bboxes}
141144

142145
return sample
146+
147+
def plot(
148+
self,
149+
sample: dict[str, Any],
150+
show_titles: bool = True,
151+
suptitle: str | None = None,
152+
) -> Figure:
153+
"""Plot a sample from the dataset.
154+
155+
Args:
156+
sample: a sample return by :meth:`__getitem__`
157+
show_titles: flag indicating whether to show titles above each panel
158+
suptitle: optional suptitle to use for Figure
159+
Returns:
160+
a matplotlib Figure with the rendered sample
161+
162+
.. versionadded:: 0.8
163+
"""
164+
# Create figure and axis - using regular matplotlib axes
165+
fig, ax = plt.subplots(figsize=(10, 8))
166+
ax.grid(ls='--')
167+
168+
# Extract bounding boxes (coordinates) from the sample
169+
bboxes = sample['bounds']
170+
171+
# Extract coordinates and timestamps
172+
longitudes = [bbox[0] for bbox in bboxes] # minx
173+
latitudes = [bbox[1] for bbox in bboxes] # miny
174+
timestamps = [bbox[2] for bbox in bboxes] # mint
175+
176+
# Plot the points with colors based on date
177+
scatter = ax.scatter(longitudes, latitudes, c=timestamps, edgecolors='black')
178+
179+
# Create a formatter function
180+
def format_date(x: float, pos: int | None = None) -> str:
181+
# Convert timestamp to datetime
182+
return datetime.fromtimestamp(x).strftime('%Y-%m-%d')
183+
184+
# Add a colorbar
185+
cbar = fig.colorbar(scatter, ax=ax, pad=0.04)
186+
cbar.set_label('Observed Timestamp', rotation=90, labelpad=-100, va='center')
187+
188+
# Apply the formatter to the colorbar
189+
cbar.ax.yaxis.set_major_formatter(FuncFormatter(format_date))
190+
191+
# Set labels
192+
ax.set_xlabel('Longitude')
193+
ax.set_ylabel('Latitude')
194+
195+
# Add titles if requested
196+
if show_titles:
197+
ax.set_title('GBIF Occurence Locations by Date')
198+
199+
if suptitle is not None:
200+
fig.suptitle(suptitle)
201+
202+
fig.tight_layout()
203+
return fig

0 commit comments

Comments
 (0)