diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b66f1b85f..e601052b14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - EME now supports 2D simulations. - 'EMESimulation' now supports 'PermittivityMonitor'. +### Fixed +- Fixed issue with `CustomMedium` gradients where other frequencies would wrongly contribute to the gradient. + ## [2.8.3] - 2025-04-24 ### Added diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index c5297eb76d..a4d582ced5 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -1433,7 +1433,9 @@ def J(eps): dJ_deps = ag.holomorphic_grad(J)(eps0) monkeypatch.setattr( - td.PoleResidue, "derivative_eps_complex_volume", lambda self, E_der_map, bounds: dJ_deps + td.PoleResidue, + "derivative_eps_complex_volume", + lambda self, E_der_map, bounds, freqs: dJ_deps, ) import importlib @@ -1516,7 +1518,7 @@ def J(eps): monkeypatch.setattr( td.CustomPoleResidue, "_derivative_field_cmp", - lambda self, E_der_map, eps_data, dim: dJ_deps, + lambda self, E_der_map, eps_data, dim, freqs: dJ_deps, ) import importlib diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index f545b5e2e2..1383ce93bf 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -15,6 +15,7 @@ import numpy as npo import pydantic.v1 as pd import xarray as xr +from numpy.typing import NDArray from scipy import signal from tidy3d.components.material.tcad.heat import ThermalSpecType @@ -1383,15 +1384,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM raise NotImplementedError(f"Can't compute derivative for 'Medium': '{type(self)}'.") def derivative_eps_sigma_volume( - self, - E_der_map: ElectromagneticFieldDataset, - bounds: Bound, + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray ) -> dict[str, xr.DataArray]: """Get the derivative w.r.t permittivity and conductivity in the volume.""" - vjp_eps_complex = self.derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds) + vjp_eps_complex = self.derivative_eps_complex_volume( + E_der_map=E_der_map, bounds=bounds, freqs=freqs + ) - freqs = vjp_eps_complex.coords["f"].values values = vjp_eps_complex.values eps_vjp, sigma_vjp = self.eps_complex_to_eps_sigma(eps_complex=values, freq=freqs) @@ -1402,13 +1402,13 @@ def derivative_eps_sigma_volume( return dict(permittivity=eps_vjp, conductivity=sigma_vjp) def derivative_eps_complex_volume( - self, E_der_map: ElectromagneticFieldDataset, bounds: Bound + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray ) -> xr.DataArray: """Get the derivative w.r.t complex-valued permittivity in the volume.""" vjp_value = 0.0 for field_name in ("Ex", "Ey", "Ez"): - fld = E_der_map[field_name] + fld = E_der_map[field_name].sel(f=freqs) vjp_value_fld = integrate_within_bounds( arr=fld, dims=("x", "y", "z"), @@ -1667,6 +1667,7 @@ def _derivative_field_cmp( E_der_map: ElectromagneticFieldDataset, eps_data: PermittivityDataset, dim: str, + freqs: NDArray, ) -> np.ndarray: coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1} dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp} @@ -1697,7 +1698,7 @@ def _derivative_field_cmp( d_vol = np.array(1.0) # TODO: probably this could be more robust. eg if the DataArray has weird edge cases - E_der_dim = E_der_map[f"E{dim}"] + E_der_dim = E_der_map[f"E{dim}"].sel(f=freqs) E_der_dim_interp = ( E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f") ) @@ -1919,7 +1920,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM # get vjps w.r.t. permittivity and conductivity of the bulk vjps_volume = self.derivative_eps_sigma_volume( - E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds + E_der_map=derivative_info.E_der_map, + bounds=derivative_info.bounds, + freqs=np.atleast_1d(derivative_info.frequency), ) # store the fields asked for by ``field_paths`` @@ -1932,15 +1935,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM return derivative_map def derivative_eps_sigma_volume( - self, - E_der_map: ElectromagneticFieldDataset, - bounds: Bound, + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray ) -> dict[str, xr.DataArray]: """Get the derivative w.r.t permittivity and conductivity in the volume.""" - vjp_eps_complex = self.derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds) + vjp_eps_complex = self.derivative_eps_complex_volume( + E_der_map=E_der_map, bounds=bounds, freqs=freqs + ) - freqs = vjp_eps_complex.coords["f"].values values = vjp_eps_complex.values # vjp of eps_complex_to_eps_sigma @@ -1954,13 +1956,13 @@ def derivative_eps_sigma_volume( return dict(permittivity=eps_vjp, conductivity=sigma_vjp) def derivative_eps_complex_volume( - self, E_der_map: ElectromagneticFieldDataset, bounds: Bound + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray ) -> xr.DataArray: """Get the derivative w.r.t complex-valued permittivity in the volume.""" vjp_value = 0.0 for field_name in ("Ex", "Ey", "Ez"): - fld = E_der_map[field_name] + fld = E_der_map[field_name].sel(f=freqs) vjp_value_fld = integrate_within_bounds( arr=fld, dims=("x", "y", "z"), @@ -1968,7 +1970,7 @@ def derivative_eps_complex_volume( ) vjp_value += vjp_value_fld - return vjp_value + return vjp_value.sum("f") class CustomIsotropicMedium(AbstractCustomMedium, Medium): @@ -2837,7 +2839,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM vjp_array = 0.0 for dim in "xyz": vjp_array += self._derivative_field_cmp( - E_der_map=derivative_info.E_der_map, eps_data=self.permittivity, dim=dim + E_der_map=derivative_info.E_der_map, + eps_data=self.permittivity, + dim=dim, + freqs=np.atleast_1d(derivative_info.frequency), ) vjps[field_path] = vjp_array @@ -2848,6 +2853,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM E_der_map=derivative_info.E_der_map, eps_data=self.eps_dataset.field_components[key], dim=dim, + freqs=np.atleast_1d(derivative_info.frequency), ) else: @@ -2862,13 +2868,14 @@ def _derivative_field_cmp( E_der_map: ElectromagneticFieldDataset, eps_data: PermittivityDataset, dim: str, + freqs: NDArray, ) -> np.ndarray: """Compute derivative with respect to the ``dim`` components within the custom medium.""" coords_interp = {key: eps_data.coords[key] for key in "xyz"} coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} - E_der_dim_interp = E_der_map[f"E{dim}"] + E_der_dim_interp = E_der_map[f"E{dim}"].sel(f=freqs) for dim_ in "xyz": if dim_ not in coords_interp: @@ -3420,7 +3427,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM # compute all derivatives beforehand dJ_deps = self.derivative_eps_complex_volume( - E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds + E_der_map=derivative_info.E_der_map, + bounds=derivative_info.bounds, + freqs=np.atleast_1d(derivative_info.frequency), ) dJ_deps = complex(dJ_deps) @@ -3896,7 +3905,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM dJ_deps = 0.0 for dim in "xyz": dJ_deps += self._derivative_field_cmp( - E_der_map=derivative_info.E_der_map, eps_data=self.eps_inf, dim=dim + E_der_map=derivative_info.E_der_map, + eps_data=self.eps_inf, + dim=dim, + freqs=np.atleast_1d(derivative_info.frequency), ) # TODO: fix for multi-frequency