Skip to content

Commit c2fd8fd

Browse files
Update create_grid args to improve usability (#507)
* Refactors create_bounds from BoundsAccessor * Adds create_axis function * Deprecates create_grid's **kwargs and implements new x, y, z arguments * Fixes how create_grid creates the Dataset * Updates create_*_grid methods to use new create_grid * Fixes create_grid method signature * Removes old documentation * Adds proper deprecation notice to docstring * Updates vertical regrid example to use new create_grid * Apply suggestions from code review Co-authored-by: Tom Vo <tomvothecoder@gmail.com> * Fixes converting standard name to cf axis * Fixes formatting * Adds additional suggested fixes --------- Co-authored-by: Tom Vo <tomvothecoder@gmail.com>
1 parent b46c331 commit c2fd8fd

File tree

6 files changed

+574
-209
lines changed

6 files changed

+574
-209
lines changed

docs/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Below is a list of top-level API functions that are available in ``xcdat``.
2727
compare_datasets
2828
get_dim_coords
2929
get_dim_keys
30+
create_axis
3031
create_gaussian_grid
3132
create_global_mean_grid
3233
create_grid

docs/examples/regridding-vertical.ipynb

+101-72
Large diffs are not rendered by default.

tests/test_regrid.py

+146-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import datetime
2+
import re
23
import sys
4+
import warnings
35
from unittest import mock
46

57
import numpy as np
@@ -770,13 +772,144 @@ def test_preserve_bounds(self):
770772

771773

772774
class TestGrid:
775+
@pytest.fixture(autouse=True)
776+
def setUp(self):
777+
self.lat_data = np.array([-45, 0, 45])
778+
self.lat = xr.DataArray(self.lat_data.copy(), dims=["lat"], name="lat")
779+
780+
self.lat_bnds_data = np.array([[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]])
781+
self.lat_bnds = xr.DataArray(
782+
self.lat_bnds_data.copy(), dims=["lat", "bnds"], name="lat_bnds"
783+
)
784+
785+
self.lon_data = np.array([30, 60, 90, 120, 150])
786+
self.lon = xr.DataArray(self.lon_data.copy(), dims=["lon"], name="lon")
787+
788+
self.lon_bnds_data = np.array(
789+
[[15, 45], [45, 75], [75, 105], [105, 135], [135, 165]]
790+
)
791+
self.lon_bnds = xr.DataArray(
792+
self.lon_bnds_data.copy(), dims=["lon", "bnds"], name="lon_bnds"
793+
)
794+
795+
def test_create_axis(self):
796+
expected_axis_attrs = {
797+
"axis": "Y",
798+
"units": "degrees_north",
799+
"coordinate": "latitude",
800+
"bounds": "lat_bnds",
801+
}
802+
803+
axis, bnds = grid.create_axis("lat", self.lat_data)
804+
805+
assert np.array_equal(axis, self.lat_data)
806+
assert bnds is not None
807+
assert bnds.attrs["xcdat_bounds"] == "True"
808+
assert axis.attrs == expected_axis_attrs
809+
810+
def test_create_axis_user_attrs(self):
811+
expected_axis_attrs = {
812+
"axis": "Y",
813+
"units": "degrees_south",
814+
"coordinate": "latitude",
815+
"bounds": "lat_bnds",
816+
"custom": "value",
817+
}
818+
819+
axis, bnds = grid.create_axis(
820+
"lat", self.lat_data, attrs={"custom": "value", "units": "degrees_south"}
821+
)
822+
823+
assert np.array_equal(axis, self.lat_data)
824+
assert bnds is not None
825+
assert bnds.attrs["xcdat_bounds"] == "True"
826+
assert axis.attrs == expected_axis_attrs
827+
828+
def test_create_axis_from_list(self):
829+
axis, bnds = grid.create_axis("lat", self.lat_data, bounds=self.lat_bnds_data)
830+
831+
assert np.array_equal(axis, self.lat_data)
832+
assert bnds is not None
833+
assert np.array_equal(bnds, self.lat_bnds_data)
834+
835+
def test_create_axis_no_bnds(self):
836+
expected_axis_attrs = {
837+
"axis": "Y",
838+
"units": "degrees_north",
839+
"coordinate": "latitude",
840+
}
841+
842+
axis, bnds = grid.create_axis("lat", self.lat_data, generate_bounds=False)
843+
844+
assert np.array_equal(axis, self.lat_data)
845+
assert bnds is None
846+
assert axis.attrs == expected_axis_attrs
847+
848+
def test_create_axis_user_bnds(self):
849+
expected_axis_attrs = {
850+
"axis": "Y",
851+
"units": "degrees_north",
852+
"coordinate": "latitude",
853+
"bounds": "lat_bnds",
854+
}
855+
856+
axis, bnds = grid.create_axis("lat", self.lat_data, bounds=self.lat_bnds_data)
857+
858+
assert np.array_equal(axis, self.lat_data)
859+
assert bnds is not None
860+
assert np.array_equal(bnds, self.lat_bnds_data)
861+
assert "xcdat_bounds" not in bnds.attrs
862+
assert axis.attrs == expected_axis_attrs
863+
864+
def test_create_axis_invalid_name(self):
865+
with pytest.raises(
866+
ValueError, match="The name 'mass' is not valid for an axis name."
867+
):
868+
grid.create_axis("mass", self.lat_data)
869+
773870
def test_empty_grid(self):
774871
with pytest.raises(
775-
ValueError, match="Must pass at least 1 coordinate to create a grid."
872+
ValueError, match="Must pass at least 1 axis to create a grid."
776873
):
777874
grid.create_grid()
778875

779-
def test_unexpected_coordinate(self):
876+
def test_create_grid(self):
877+
new_grid = grid.create_grid(x=self.lon, y=self.lat)
878+
879+
assert np.array_equal(new_grid.lat, self.lat)
880+
assert np.array_equal(new_grid.lon, self.lon)
881+
882+
def test_create_grid_with_bounds(self):
883+
new_grid = grid.create_grid(
884+
x=(self.lon, self.lon_bnds), y=(self.lat, self.lat_bnds)
885+
)
886+
887+
assert np.array_equal(new_grid.lat, self.lat)
888+
assert new_grid.lat.attrs["bounds"] == self.lat_bnds.name
889+
assert np.array_equal(new_grid.lat_bnds, self.lat_bnds)
890+
891+
assert np.array_equal(new_grid.lon, self.lon)
892+
assert new_grid.lon.attrs["bounds"] == self.lon_bnds.name
893+
assert np.array_equal(new_grid.lon_bnds, self.lon_bnds)
894+
895+
def test_create_grid_user_attrs(self):
896+
lev = xr.DataArray(np.linspace(1000, 1, 2), dims=["lev"], name="lev")
897+
898+
new_grid = grid.create_grid(z=lev, attrs={"custom": "value"})
899+
900+
assert "custom" in new_grid.attrs
901+
assert new_grid.attrs["custom"] == "value"
902+
903+
def test_create_grid_wrong_axis_value(self):
904+
with pytest.raises(
905+
ValueError,
906+
match=re.escape(
907+
"Argument 'x' should be an xr.DataArray representing coordinates or a tuple (xr.DataArray, xr.DataArray) representing coordinates and bounds."
908+
),
909+
):
910+
grid.create_grid(x=(self.lon, self.lon_bnds, self.lat)) # type: ignore[arg-type]
911+
912+
def test_deprecated_unexpected_coordinate(self):
780913
lev = np.linspace(1000, 1, 2)
781914

782915
with pytest.raises(
@@ -785,16 +918,24 @@ def test_unexpected_coordinate(self):
785918
):
786919
grid.create_grid(lev=lev, mass=np.linspace(10, 20, 2))
787920

788-
def test_create_grid_lev(self):
921+
def test_deprecated_create_grid_lev(self):
789922
lev = np.linspace(1000, 1, 2)
790923
lev_bnds = np.array([[1499.5, 500.5], [500.5, -498.5]])
791924

792-
new_grid = grid.create_grid(lev=(lev, lev_bnds))
925+
with warnings.catch_warnings(record=True) as w:
926+
new_grid = grid.create_grid(lev=(lev, lev_bnds))
927+
928+
assert len(w) == 1
929+
assert issubclass(w[0].category, DeprecationWarning)
930+
assert (
931+
str(w[0].message)
932+
== "**kwargs will be deprecated, see docstring and use 'x', 'y', or 'z' arguments"
933+
)
793934

794935
assert np.array_equal(new_grid.lev, lev)
795936
assert np.array_equal(new_grid.lev_bnds, lev_bnds)
796937

797-
def test_create_grid(self):
938+
def test_deprecated_create_grid(self):
798939
lat = np.array([-45, 0, 45])
799940
lon = np.array([30, 60, 90, 120, 150])
800941
lat_bnds = np.array([[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]])

xcdat/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from xcdat.dataset import decode_time, open_dataset, open_mfdataset # noqa: F401
1010
from xcdat.regridder.accessor import RegridderAccessor # noqa: F401
1111
from xcdat.regridder.grid import ( # noqa: F401
12+
create_axis,
1213
create_gaussian_grid,
1314
create_global_mean_grid,
1415
create_grid,

0 commit comments

Comments
 (0)