Skip to content

fix[autograd]: remove frequency summing in CustomMedium gradient and … #2430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- EME now supports 2D simulations.
- 'EMESimulation' now supports 'PermittivityMonitor'.

### Fixed
- Fixed issue with `CustomMedium` gradients where other frequencies would wrongly contribute to the gradient.

## [2.8.3] - 2025-04-24

### Added
Expand Down
6 changes: 4 additions & 2 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,9 @@ def J(eps):
dJ_deps = ag.holomorphic_grad(J)(eps0)

monkeypatch.setattr(
td.PoleResidue, "derivative_eps_complex_volume", lambda self, E_der_map, bounds: dJ_deps
td.PoleResidue,
"derivative_eps_complex_volume",
lambda self, E_der_map, bounds, freqs: dJ_deps,
)

import importlib
Expand Down Expand Up @@ -1516,7 +1518,7 @@ def J(eps):
monkeypatch.setattr(
td.CustomPoleResidue,
"_derivative_field_cmp",
lambda self, E_der_map, eps_data, dim: dJ_deps,
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps,
)

import importlib
Expand Down
54 changes: 33 additions & 21 deletions tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as npo
import pydantic.v1 as pd
import xarray as xr
from numpy.typing import NDArray
from scipy import signal

from tidy3d.components.material.tcad.heat import ThermalSpecType
Expand Down Expand Up @@ -1383,15 +1384,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
raise NotImplementedError(f"Can't compute derivative for 'Medium': '{type(self)}'.")

def derivative_eps_sigma_volume(
self,
E_der_map: ElectromagneticFieldDataset,
bounds: Bound,
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
) -> dict[str, xr.DataArray]:
"""Get the derivative w.r.t permittivity and conductivity in the volume."""

vjp_eps_complex = self.derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds)
vjp_eps_complex = self.derivative_eps_complex_volume(
E_der_map=E_der_map, bounds=bounds, freqs=freqs
)

freqs = vjp_eps_complex.coords["f"].values
values = vjp_eps_complex.values

eps_vjp, sigma_vjp = self.eps_complex_to_eps_sigma(eps_complex=values, freq=freqs)
Expand All @@ -1402,13 +1402,13 @@ def derivative_eps_sigma_volume(
return dict(permittivity=eps_vjp, conductivity=sigma_vjp)

def derivative_eps_complex_volume(
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
) -> xr.DataArray:
"""Get the derivative w.r.t complex-valued permittivity in the volume."""

vjp_value = 0.0
for field_name in ("Ex", "Ey", "Ez"):
fld = E_der_map[field_name]
fld = E_der_map[field_name].sel(f=freqs)
vjp_value_fld = integrate_within_bounds(
arr=fld,
dims=("x", "y", "z"),
Expand Down Expand Up @@ -1667,6 +1667,7 @@ def _derivative_field_cmp(
E_der_map: ElectromagneticFieldDataset,
eps_data: PermittivityDataset,
dim: str,
freqs: NDArray,
) -> np.ndarray:
coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1}
dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp}
Expand Down Expand Up @@ -1697,7 +1698,7 @@ def _derivative_field_cmp(
d_vol = np.array(1.0)

# TODO: probably this could be more robust. eg if the DataArray has weird edge cases
E_der_dim = E_der_map[f"E{dim}"]
E_der_dim = E_der_map[f"E{dim}"].sel(f=freqs)
E_der_dim_interp = (
E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f")
)
Expand Down Expand Up @@ -1919,7 +1920,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM

# get vjps w.r.t. permittivity and conductivity of the bulk
vjps_volume = self.derivative_eps_sigma_volume(
E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds
E_der_map=derivative_info.E_der_map,
bounds=derivative_info.bounds,
freqs=np.atleast_1d(derivative_info.frequency),
)

# store the fields asked for by ``field_paths``
Expand All @@ -1932,15 +1935,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
return derivative_map

def derivative_eps_sigma_volume(
self,
E_der_map: ElectromagneticFieldDataset,
bounds: Bound,
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
) -> dict[str, xr.DataArray]:
"""Get the derivative w.r.t permittivity and conductivity in the volume."""

vjp_eps_complex = self.derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds)
vjp_eps_complex = self.derivative_eps_complex_volume(
E_der_map=E_der_map, bounds=bounds, freqs=freqs
)

freqs = vjp_eps_complex.coords["f"].values
values = vjp_eps_complex.values

# vjp of eps_complex_to_eps_sigma
Expand All @@ -1954,21 +1956,21 @@ def derivative_eps_sigma_volume(
return dict(permittivity=eps_vjp, conductivity=sigma_vjp)

def derivative_eps_complex_volume(
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
) -> xr.DataArray:
"""Get the derivative w.r.t complex-valued permittivity in the volume."""

vjp_value = 0.0
for field_name in ("Ex", "Ey", "Ez"):
fld = E_der_map[field_name]
fld = E_der_map[field_name].sel(f=freqs)
vjp_value_fld = integrate_within_bounds(
arr=fld,
dims=("x", "y", "z"),
bounds=bounds,
)
vjp_value += vjp_value_fld

return vjp_value
return vjp_value.sum("f")


class CustomIsotropicMedium(AbstractCustomMedium, Medium):
Expand Down Expand Up @@ -2837,7 +2839,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
vjp_array = 0.0
for dim in "xyz":
vjp_array += self._derivative_field_cmp(
E_der_map=derivative_info.E_der_map, eps_data=self.permittivity, dim=dim
E_der_map=derivative_info.E_der_map,
eps_data=self.permittivity,
dim=dim,
freqs=np.atleast_1d(derivative_info.frequency),
)
vjps[field_path] = vjp_array

Expand All @@ -2848,6 +2853,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
E_der_map=derivative_info.E_der_map,
eps_data=self.eps_dataset.field_components[key],
dim=dim,
freqs=np.atleast_1d(derivative_info.frequency),
)

else:
Expand All @@ -2862,13 +2868,14 @@ def _derivative_field_cmp(
E_der_map: ElectromagneticFieldDataset,
eps_data: PermittivityDataset,
dim: str,
freqs: NDArray,
) -> np.ndarray:
"""Compute derivative with respect to the ``dim`` components within the custom medium."""

coords_interp = {key: eps_data.coords[key] for key in "xyz"}
coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1}

E_der_dim_interp = E_der_map[f"E{dim}"]
E_der_dim_interp = E_der_map[f"E{dim}"].sel(f=freqs)

for dim_ in "xyz":
if dim_ not in coords_interp:
Expand Down Expand Up @@ -3420,7 +3427,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM

# compute all derivatives beforehand
dJ_deps = self.derivative_eps_complex_volume(
E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds
E_der_map=derivative_info.E_der_map,
bounds=derivative_info.bounds,
freqs=np.atleast_1d(derivative_info.frequency),
)

dJ_deps = complex(dJ_deps)
Expand Down Expand Up @@ -3896,7 +3905,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
dJ_deps = 0.0
for dim in "xyz":
dJ_deps += self._derivative_field_cmp(
E_der_map=derivative_info.E_der_map, eps_data=self.eps_inf, dim=dim
E_der_map=derivative_info.E_der_map,
eps_data=self.eps_inf,
dim=dim,
freqs=np.atleast_1d(derivative_info.frequency),
)

# TODO: fix for multi-frequency
Expand Down