1
1
import datetime
2
+ import re
2
3
import sys
4
+ import warnings
3
5
from unittest import mock
4
6
5
7
import numpy as np
@@ -770,13 +772,144 @@ def test_preserve_bounds(self):
770
772
771
773
772
774
class TestGrid :
775
+ @pytest .fixture (autouse = True )
776
+ def setUp (self ):
777
+ self .lat_data = np .array ([- 45 , 0 , 45 ])
778
+ self .lat = xr .DataArray (self .lat_data .copy (), dims = ["lat" ], name = "lat" )
779
+
780
+ self .lat_bnds_data = np .array ([[- 67.5 , - 22.5 ], [- 22.5 , 22.5 ], [22.5 , 67.5 ]])
781
+ self .lat_bnds = xr .DataArray (
782
+ self .lat_bnds_data .copy (), dims = ["lat" , "bnds" ], name = "lat_bnds"
783
+ )
784
+
785
+ self .lon_data = np .array ([30 , 60 , 90 , 120 , 150 ])
786
+ self .lon = xr .DataArray (self .lon_data .copy (), dims = ["lon" ], name = "lon" )
787
+
788
+ self .lon_bnds_data = np .array (
789
+ [[15 , 45 ], [45 , 75 ], [75 , 105 ], [105 , 135 ], [135 , 165 ]]
790
+ )
791
+ self .lon_bnds = xr .DataArray (
792
+ self .lon_bnds_data .copy (), dims = ["lon" , "bnds" ], name = "lon_bnds"
793
+ )
794
+
795
+ def test_create_axis (self ):
796
+ expected_axis_attrs = {
797
+ "axis" : "Y" ,
798
+ "units" : "degrees_north" ,
799
+ "coordinate" : "latitude" ,
800
+ "bounds" : "lat_bnds" ,
801
+ }
802
+
803
+ axis , bnds = grid .create_axis ("lat" , self .lat_data )
804
+
805
+ assert np .array_equal (axis , self .lat_data )
806
+ assert bnds is not None
807
+ assert bnds .attrs ["xcdat_bounds" ] == "True"
808
+ assert axis .attrs == expected_axis_attrs
809
+
810
+ def test_create_axis_user_attrs (self ):
811
+ expected_axis_attrs = {
812
+ "axis" : "Y" ,
813
+ "units" : "degrees_south" ,
814
+ "coordinate" : "latitude" ,
815
+ "bounds" : "lat_bnds" ,
816
+ "custom" : "value" ,
817
+ }
818
+
819
+ axis , bnds = grid .create_axis (
820
+ "lat" , self .lat_data , attrs = {"custom" : "value" , "units" : "degrees_south" }
821
+ )
822
+
823
+ assert np .array_equal (axis , self .lat_data )
824
+ assert bnds is not None
825
+ assert bnds .attrs ["xcdat_bounds" ] == "True"
826
+ assert axis .attrs == expected_axis_attrs
827
+
828
+ def test_create_axis_from_list (self ):
829
+ axis , bnds = grid .create_axis ("lat" , self .lat_data , bounds = self .lat_bnds_data )
830
+
831
+ assert np .array_equal (axis , self .lat_data )
832
+ assert bnds is not None
833
+ assert np .array_equal (bnds , self .lat_bnds_data )
834
+
835
+ def test_create_axis_no_bnds (self ):
836
+ expected_axis_attrs = {
837
+ "axis" : "Y" ,
838
+ "units" : "degrees_north" ,
839
+ "coordinate" : "latitude" ,
840
+ }
841
+
842
+ axis , bnds = grid .create_axis ("lat" , self .lat_data , generate_bounds = False )
843
+
844
+ assert np .array_equal (axis , self .lat_data )
845
+ assert bnds is None
846
+ assert axis .attrs == expected_axis_attrs
847
+
848
+ def test_create_axis_user_bnds (self ):
849
+ expected_axis_attrs = {
850
+ "axis" : "Y" ,
851
+ "units" : "degrees_north" ,
852
+ "coordinate" : "latitude" ,
853
+ "bounds" : "lat_bnds" ,
854
+ }
855
+
856
+ axis , bnds = grid .create_axis ("lat" , self .lat_data , bounds = self .lat_bnds_data )
857
+
858
+ assert np .array_equal (axis , self .lat_data )
859
+ assert bnds is not None
860
+ assert np .array_equal (bnds , self .lat_bnds_data )
861
+ assert "xcdat_bounds" not in bnds .attrs
862
+ assert axis .attrs == expected_axis_attrs
863
+
864
+ def test_create_axis_invalid_name (self ):
865
+ with pytest .raises (
866
+ ValueError , match = "The name 'mass' is not valid for an axis name."
867
+ ):
868
+ grid .create_axis ("mass" , self .lat_data )
869
+
773
870
def test_empty_grid (self ):
774
871
with pytest .raises (
775
- ValueError , match = "Must pass at least 1 coordinate to create a grid."
872
+ ValueError , match = "Must pass at least 1 axis to create a grid."
776
873
):
777
874
grid .create_grid ()
778
875
779
- def test_unexpected_coordinate (self ):
876
+ def test_create_grid (self ):
877
+ new_grid = grid .create_grid (x = self .lon , y = self .lat )
878
+
879
+ assert np .array_equal (new_grid .lat , self .lat )
880
+ assert np .array_equal (new_grid .lon , self .lon )
881
+
882
+ def test_create_grid_with_bounds (self ):
883
+ new_grid = grid .create_grid (
884
+ x = (self .lon , self .lon_bnds ), y = (self .lat , self .lat_bnds )
885
+ )
886
+
887
+ assert np .array_equal (new_grid .lat , self .lat )
888
+ assert new_grid .lat .attrs ["bounds" ] == self .lat_bnds .name
889
+ assert np .array_equal (new_grid .lat_bnds , self .lat_bnds )
890
+
891
+ assert np .array_equal (new_grid .lon , self .lon )
892
+ assert new_grid .lon .attrs ["bounds" ] == self .lon_bnds .name
893
+ assert np .array_equal (new_grid .lon_bnds , self .lon_bnds )
894
+
895
+ def test_create_grid_user_attrs (self ):
896
+ lev = xr .DataArray (np .linspace (1000 , 1 , 2 ), dims = ["lev" ], name = "lev" )
897
+
898
+ new_grid = grid .create_grid (z = lev , attrs = {"custom" : "value" })
899
+
900
+ assert "custom" in new_grid .attrs
901
+ assert new_grid .attrs ["custom" ] == "value"
902
+
903
+ def test_create_grid_wrong_axis_value (self ):
904
+ with pytest .raises (
905
+ ValueError ,
906
+ match = re .escape (
907
+ "Argument 'x' should be an xr.DataArray representing coordinates or a tuple (xr.DataArray, xr.DataArray) representing coordinates and bounds."
908
+ ),
909
+ ):
910
+ grid .create_grid (x = (self .lon , self .lon_bnds , self .lat )) # type: ignore[arg-type]
911
+
912
+ def test_deprecated_unexpected_coordinate (self ):
780
913
lev = np .linspace (1000 , 1 , 2 )
781
914
782
915
with pytest .raises (
@@ -785,16 +918,24 @@ def test_unexpected_coordinate(self):
785
918
):
786
919
grid .create_grid (lev = lev , mass = np .linspace (10 , 20 , 2 ))
787
920
788
- def test_create_grid_lev (self ):
921
+ def test_deprecated_create_grid_lev (self ):
789
922
lev = np .linspace (1000 , 1 , 2 )
790
923
lev_bnds = np .array ([[1499.5 , 500.5 ], [500.5 , - 498.5 ]])
791
924
792
- new_grid = grid .create_grid (lev = (lev , lev_bnds ))
925
+ with warnings .catch_warnings (record = True ) as w :
926
+ new_grid = grid .create_grid (lev = (lev , lev_bnds ))
927
+
928
+ assert len (w ) == 1
929
+ assert issubclass (w [0 ].category , DeprecationWarning )
930
+ assert (
931
+ str (w [0 ].message )
932
+ == "**kwargs will be deprecated, see docstring and use 'x', 'y', or 'z' arguments"
933
+ )
793
934
794
935
assert np .array_equal (new_grid .lev , lev )
795
936
assert np .array_equal (new_grid .lev_bnds , lev_bnds )
796
937
797
- def test_create_grid (self ):
938
+ def test_deprecated_create_grid (self ):
798
939
lat = np .array ([- 45 , 0 , 45 ])
799
940
lon = np .array ([30 , 60 , 90 , 120 , 150 ])
800
941
lat_bnds = np .array ([[- 67.5 , - 22.5 ], [- 22.5 , 22.5 ], [22.5 , 67.5 ]])
0 commit comments