Skip to content

Commit 14cb927

Browse files
Gregory RobertsGregory Roberts
Gregory Roberts
authored and
Gregory Roberts
committed
fix[autograd]: remove frequency summing in CustomMedium gradient and select adjoint frequency
1 parent 936a88e commit 14cb927

File tree

3 files changed

+40
-23
lines changed

3 files changed

+40
-23
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121
- EME now supports 2D simulations.
2222
- 'EMESimulation' now supports 'PermittivityMonitor'.
2323

24+
### Fixed
25+
- Fixed issue with `CustomMedium` gradients where other frequencies would wrongly contribute to the gradient.
26+
2427
## [2.8.3] - 2025-04-24
2528

2629
### Added

tests/test_components/test_autograd.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,9 @@ def J(eps):
14331433
dJ_deps = ag.holomorphic_grad(J)(eps0)
14341434

14351435
monkeypatch.setattr(
1436-
td.PoleResidue, "derivative_eps_complex_volume", lambda self, E_der_map, bounds: dJ_deps
1436+
td.PoleResidue,
1437+
"derivative_eps_complex_volume",
1438+
lambda self, E_der_map, bounds, freqs: dJ_deps,
14371439
)
14381440

14391441
import importlib
@@ -1516,7 +1518,7 @@ def J(eps):
15161518
monkeypatch.setattr(
15171519
td.CustomPoleResidue,
15181520
"_derivative_field_cmp",
1519-
lambda self, E_der_map, eps_data, dim: dJ_deps,
1521+
lambda self, E_der_map, eps_data, dim, freqs: dJ_deps,
15201522
)
15211523

15221524
import importlib

tidy3d/components/medium.py

+33-21
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as npo
1616
import pydantic.v1 as pd
1717
import xarray as xr
18+
from numpy.typing import NDArray
1819
from scipy import signal
1920

2021
from tidy3d.components.material.tcad.heat import ThermalSpecType
@@ -1383,15 +1384,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
13831384
raise NotImplementedError(f"Can't compute derivative for 'Medium': '{type(self)}'.")
13841385

13851386
def derivative_eps_sigma_volume(
1386-
self,
1387-
E_der_map: ElectromagneticFieldDataset,
1388-
bounds: Bound,
1387+
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
13891388
) -> dict[str, xr.DataArray]:
13901389
"""Get the derivative w.r.t permittivity and conductivity in the volume."""
13911390

1392-
vjp_eps_complex = self.derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds)
1391+
vjp_eps_complex = self.derivative_eps_complex_volume(
1392+
E_der_map=E_der_map, bounds=bounds, freqs=freqs
1393+
)
13931394

1394-
freqs = vjp_eps_complex.coords["f"].values
13951395
values = vjp_eps_complex.values
13961396

13971397
eps_vjp, sigma_vjp = self.eps_complex_to_eps_sigma(eps_complex=values, freq=freqs)
@@ -1402,13 +1402,13 @@ def derivative_eps_sigma_volume(
14021402
return dict(permittivity=eps_vjp, conductivity=sigma_vjp)
14031403

14041404
def derivative_eps_complex_volume(
1405-
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound
1405+
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
14061406
) -> xr.DataArray:
14071407
"""Get the derivative w.r.t complex-valued permittivity in the volume."""
14081408

14091409
vjp_value = 0.0
14101410
for field_name in ("Ex", "Ey", "Ez"):
1411-
fld = E_der_map[field_name]
1411+
fld = E_der_map[field_name].sel(f=freqs)
14121412
vjp_value_fld = integrate_within_bounds(
14131413
arr=fld,
14141414
dims=("x", "y", "z"),
@@ -1667,6 +1667,7 @@ def _derivative_field_cmp(
16671667
E_der_map: ElectromagneticFieldDataset,
16681668
eps_data: PermittivityDataset,
16691669
dim: str,
1670+
freqs: NDArray,
16701671
) -> np.ndarray:
16711672
coords_interp = {key: val for key, val in eps_data.coords.items() if len(val) > 1}
16721673
dims_sum = {dim for dim in eps_data.coords.keys() if dim not in coords_interp}
@@ -1697,7 +1698,7 @@ def _derivative_field_cmp(
16971698
d_vol = np.array(1.0)
16981699

16991700
# TODO: probably this could be more robust. eg if the DataArray has weird edge cases
1700-
E_der_dim = E_der_map[f"E{dim}"]
1701+
E_der_dim = E_der_map[f"E{dim}"].sel(f=freqs)
17011702
E_der_dim_interp = (
17021703
E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f")
17031704
)
@@ -1919,7 +1920,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
19191920

19201921
# get vjps w.r.t. permittivity and conductivity of the bulk
19211922
vjps_volume = self.derivative_eps_sigma_volume(
1922-
E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds
1923+
E_der_map=derivative_info.E_der_map,
1924+
bounds=derivative_info.bounds,
1925+
freqs=np.atleast_1d(derivative_info.frequency),
19231926
)
19241927

19251928
# store the fields asked for by ``field_paths``
@@ -1932,15 +1935,14 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
19321935
return derivative_map
19331936

19341937
def derivative_eps_sigma_volume(
1935-
self,
1936-
E_der_map: ElectromagneticFieldDataset,
1937-
bounds: Bound,
1938+
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
19381939
) -> dict[str, xr.DataArray]:
19391940
"""Get the derivative w.r.t permittivity and conductivity in the volume."""
19401941

1941-
vjp_eps_complex = self.derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds)
1942+
vjp_eps_complex = self.derivative_eps_complex_volume(
1943+
E_der_map=E_der_map, bounds=bounds, freqs=freqs
1944+
)
19421945

1943-
freqs = vjp_eps_complex.coords["f"].values
19441946
values = vjp_eps_complex.values
19451947

19461948
# vjp of eps_complex_to_eps_sigma
@@ -1954,21 +1956,21 @@ def derivative_eps_sigma_volume(
19541956
return dict(permittivity=eps_vjp, conductivity=sigma_vjp)
19551957

19561958
def derivative_eps_complex_volume(
1957-
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound
1959+
self, E_der_map: ElectromagneticFieldDataset, bounds: Bound, freqs: NDArray
19581960
) -> xr.DataArray:
19591961
"""Get the derivative w.r.t complex-valued permittivity in the volume."""
19601962

19611963
vjp_value = 0.0
19621964
for field_name in ("Ex", "Ey", "Ez"):
1963-
fld = E_der_map[field_name]
1965+
fld = E_der_map[field_name].sel(f=freqs)
19641966
vjp_value_fld = integrate_within_bounds(
19651967
arr=fld,
19661968
dims=("x", "y", "z"),
19671969
bounds=bounds,
19681970
)
19691971
vjp_value += vjp_value_fld
19701972

1971-
return vjp_value
1973+
return vjp_value.sum("f")
19721974

19731975

19741976
class CustomIsotropicMedium(AbstractCustomMedium, Medium):
@@ -2837,7 +2839,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
28372839
vjp_array = 0.0
28382840
for dim in "xyz":
28392841
vjp_array += self._derivative_field_cmp(
2840-
E_der_map=derivative_info.E_der_map, eps_data=self.permittivity, dim=dim
2842+
E_der_map=derivative_info.E_der_map,
2843+
eps_data=self.permittivity,
2844+
dim=dim,
2845+
freqs=np.atleast_1d(derivative_info.frequency),
28412846
)
28422847
vjps[field_path] = vjp_array
28432848

@@ -2848,6 +2853,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
28482853
E_der_map=derivative_info.E_der_map,
28492854
eps_data=self.eps_dataset.field_components[key],
28502855
dim=dim,
2856+
freqs=np.atleast_1d(derivative_info.frequency),
28512857
)
28522858

28532859
else:
@@ -2862,13 +2868,14 @@ def _derivative_field_cmp(
28622868
E_der_map: ElectromagneticFieldDataset,
28632869
eps_data: PermittivityDataset,
28642870
dim: str,
2871+
freqs: NDArray,
28652872
) -> np.ndarray:
28662873
"""Compute derivative with respect to the ``dim`` components within the custom medium."""
28672874

28682875
coords_interp = {key: eps_data.coords[key] for key in "xyz"}
28692876
coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1}
28702877

2871-
E_der_dim_interp = E_der_map[f"E{dim}"]
2878+
E_der_dim_interp = E_der_map[f"E{dim}"].sel(f=freqs)
28722879

28732880
for dim_ in "xyz":
28742881
if dim_ not in coords_interp:
@@ -3420,7 +3427,9 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
34203427

34213428
# compute all derivatives beforehand
34223429
dJ_deps = self.derivative_eps_complex_volume(
3423-
E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds
3430+
E_der_map=derivative_info.E_der_map,
3431+
bounds=derivative_info.bounds,
3432+
freqs=np.atleast_1d(derivative_info.frequency),
34243433
)
34253434

34263435
dJ_deps = complex(dJ_deps)
@@ -3896,7 +3905,10 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
38963905
dJ_deps = 0.0
38973906
for dim in "xyz":
38983907
dJ_deps += self._derivative_field_cmp(
3899-
E_der_map=derivative_info.E_der_map, eps_data=self.eps_inf, dim=dim
3908+
E_der_map=derivative_info.E_der_map,
3909+
eps_data=self.eps_inf,
3910+
dim=dim,
3911+
freqs=np.atleast_1d(derivative_info.frequency),
39003912
)
39013913

39023914
# TODO: fix for multi-frequency

0 commit comments

Comments
 (0)