15
15
import numpy as npo
16
16
import pydantic .v1 as pd
17
17
import xarray as xr
18
+ from numpy .typing import NDArray
18
19
from scipy import signal
19
20
20
21
from tidy3d .components .material .tcad .heat import ThermalSpecType
@@ -1383,15 +1384,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
1383
1384
raise NotImplementedError (f"Can't compute derivative for 'Medium': '{ type (self )} '." )
1384
1385
1385
1386
def derivative_eps_sigma_volume (
1386
- self ,
1387
- E_der_map : ElectromagneticFieldDataset ,
1388
- bounds : Bound ,
1387
+ self , E_der_map : ElectromagneticFieldDataset , bounds : Bound , freqs : NDArray
1389
1388
) -> dict [str , xr .DataArray ]:
1390
1389
"""Get the derivative w.r.t permittivity and conductivity in the volume."""
1391
1390
1392
- vjp_eps_complex = self .derivative_eps_complex_volume (E_der_map = E_der_map , bounds = bounds )
1391
+ vjp_eps_complex = self .derivative_eps_complex_volume (
1392
+ E_der_map = E_der_map , bounds = bounds , freqs = freqs
1393
+ )
1393
1394
1394
- freqs = vjp_eps_complex .coords ["f" ].values
1395
1395
values = vjp_eps_complex .values
1396
1396
1397
1397
eps_vjp , sigma_vjp = self .eps_complex_to_eps_sigma (eps_complex = values , freq = freqs )
@@ -1402,13 +1402,13 @@ def derivative_eps_sigma_volume(
1402
1402
return dict (permittivity = eps_vjp , conductivity = sigma_vjp )
1403
1403
1404
1404
def derivative_eps_complex_volume (
1405
- self , E_der_map : ElectromagneticFieldDataset , bounds : Bound
1405
+ self , E_der_map : ElectromagneticFieldDataset , bounds : Bound , freqs : NDArray
1406
1406
) -> xr .DataArray :
1407
1407
"""Get the derivative w.r.t complex-valued permittivity in the volume."""
1408
1408
1409
1409
vjp_value = 0.0
1410
1410
for field_name in ("Ex" , "Ey" , "Ez" ):
1411
- fld = E_der_map [field_name ]
1411
+ fld = E_der_map [field_name ]. sel ( f = freqs )
1412
1412
vjp_value_fld = integrate_within_bounds (
1413
1413
arr = fld ,
1414
1414
dims = ("x" , "y" , "z" ),
@@ -1667,6 +1667,7 @@ def _derivative_field_cmp(
1667
1667
E_der_map : ElectromagneticFieldDataset ,
1668
1668
eps_data : PermittivityDataset ,
1669
1669
dim : str ,
1670
+ freqs : NDArray ,
1670
1671
) -> np .ndarray :
1671
1672
coords_interp = {key : val for key , val in eps_data .coords .items () if len (val ) > 1 }
1672
1673
dims_sum = {dim for dim in eps_data .coords .keys () if dim not in coords_interp }
@@ -1697,7 +1698,7 @@ def _derivative_field_cmp(
1697
1698
d_vol = np .array (1.0 )
1698
1699
1699
1700
# TODO: probably this could be more robust. eg if the DataArray has weird edge cases
1700
- E_der_dim = E_der_map [f"E{ dim } " ]
1701
+ E_der_dim = E_der_map [f"E{ dim } " ]. sel ( f = freqs )
1701
1702
E_der_dim_interp = (
1702
1703
E_der_dim .interp (** coords_interp , assume_sorted = True ).fillna (0.0 ).sum (dims_sum ).sum ("f" )
1703
1704
)
@@ -1919,7 +1920,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
1919
1920
1920
1921
# get vjps w.r.t. permittivity and conductivity of the bulk
1921
1922
vjps_volume = self .derivative_eps_sigma_volume (
1922
- E_der_map = derivative_info .E_der_map , bounds = derivative_info .bounds
1923
+ E_der_map = derivative_info .E_der_map ,
1924
+ bounds = derivative_info .bounds ,
1925
+ freqs = np .atleast_1d (derivative_info .frequency ),
1923
1926
)
1924
1927
1925
1928
# store the fields asked for by ``field_paths``
@@ -1932,15 +1935,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
1932
1935
return derivative_map
1933
1936
1934
1937
def derivative_eps_sigma_volume (
1935
- self ,
1936
- E_der_map : ElectromagneticFieldDataset ,
1937
- bounds : Bound ,
1938
+ self , E_der_map : ElectromagneticFieldDataset , bounds : Bound , freqs : NDArray
1938
1939
) -> dict [str , xr .DataArray ]:
1939
1940
"""Get the derivative w.r.t permittivity and conductivity in the volume."""
1940
1941
1941
- vjp_eps_complex = self .derivative_eps_complex_volume (E_der_map = E_der_map , bounds = bounds )
1942
+ vjp_eps_complex = self .derivative_eps_complex_volume (
1943
+ E_der_map = E_der_map , bounds = bounds , freqs = freqs
1944
+ )
1942
1945
1943
- freqs = vjp_eps_complex .coords ["f" ].values
1944
1946
values = vjp_eps_complex .values
1945
1947
1946
1948
# vjp of eps_complex_to_eps_sigma
@@ -1954,21 +1956,21 @@ def derivative_eps_sigma_volume(
1954
1956
return dict (permittivity = eps_vjp , conductivity = sigma_vjp )
1955
1957
1956
1958
def derivative_eps_complex_volume (
1957
- self , E_der_map : ElectromagneticFieldDataset , bounds : Bound
1959
+ self , E_der_map : ElectromagneticFieldDataset , bounds : Bound , freqs : NDArray
1958
1960
) -> xr .DataArray :
1959
1961
"""Get the derivative w.r.t complex-valued permittivity in the volume."""
1960
1962
1961
1963
vjp_value = 0.0
1962
1964
for field_name in ("Ex" , "Ey" , "Ez" ):
1963
- fld = E_der_map [field_name ]
1965
+ fld = E_der_map [field_name ]. sel ( f = freqs )
1964
1966
vjp_value_fld = integrate_within_bounds (
1965
1967
arr = fld ,
1966
1968
dims = ("x" , "y" , "z" ),
1967
1969
bounds = bounds ,
1968
1970
)
1969
1971
vjp_value += vjp_value_fld
1970
1972
1971
- return vjp_value
1973
+ return vjp_value . sum ( "f" )
1972
1974
1973
1975
1974
1976
class CustomIsotropicMedium (AbstractCustomMedium , Medium ):
@@ -2837,7 +2839,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
2837
2839
vjp_array = 0.0
2838
2840
for dim in "xyz" :
2839
2841
vjp_array += self ._derivative_field_cmp (
2840
- E_der_map = derivative_info .E_der_map , eps_data = self .permittivity , dim = dim
2842
+ E_der_map = derivative_info .E_der_map ,
2843
+ eps_data = self .permittivity ,
2844
+ dim = dim ,
2845
+ freqs = np .atleast_1d (derivative_info .frequency ),
2841
2846
)
2842
2847
vjps [field_path ] = vjp_array
2843
2848
@@ -2848,6 +2853,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
2848
2853
E_der_map = derivative_info .E_der_map ,
2849
2854
eps_data = self .eps_dataset .field_components [key ],
2850
2855
dim = dim ,
2856
+ freqs = np .atleast_1d (derivative_info .frequency ),
2851
2857
)
2852
2858
2853
2859
else :
@@ -2862,13 +2868,14 @@ def _derivative_field_cmp(
2862
2868
E_der_map : ElectromagneticFieldDataset ,
2863
2869
eps_data : PermittivityDataset ,
2864
2870
dim : str ,
2871
+ freqs : NDArray ,
2865
2872
) -> np .ndarray :
2866
2873
"""Compute derivative with respect to the ``dim`` components within the custom medium."""
2867
2874
2868
2875
coords_interp = {key : eps_data .coords [key ] for key in "xyz" }
2869
2876
coords_interp = {key : val for key , val in coords_interp .items () if len (val ) > 1 }
2870
2877
2871
- E_der_dim_interp = E_der_map [f"E{ dim } " ]
2878
+ E_der_dim_interp = E_der_map [f"E{ dim } " ]. sel ( f = freqs )
2872
2879
2873
2880
for dim_ in "xyz" :
2874
2881
if dim_ not in coords_interp :
@@ -3420,7 +3427,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
3420
3427
3421
3428
# compute all derivatives beforehand
3422
3429
dJ_deps = self .derivative_eps_complex_volume (
3423
- E_der_map = derivative_info .E_der_map , bounds = derivative_info .bounds
3430
+ E_der_map = derivative_info .E_der_map ,
3431
+ bounds = derivative_info .bounds ,
3432
+ freqs = np .atleast_1d (derivative_info .frequency ),
3424
3433
)
3425
3434
3426
3435
dJ_deps = complex (dJ_deps )
@@ -3896,7 +3905,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
3896
3905
dJ_deps = 0.0
3897
3906
for dim in "xyz" :
3898
3907
dJ_deps += self ._derivative_field_cmp (
3899
- E_der_map = derivative_info .E_der_map , eps_data = self .eps_inf , dim = dim
3908
+ E_der_map = derivative_info .E_der_map ,
3909
+ eps_data = self .eps_inf ,
3910
+ dim = dim ,
3911
+ freqs = np .atleast_1d (derivative_info .frequency ),
3900
3912
)
3901
3913
3902
3914
# TODO: fix for multi-frequency
0 commit comments