Skip to content

Commit 3f4df26

Browse files
committed
more bugfixes in export from gpu (#66)
1 parent 91b5d93 commit 3f4df26

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

fdtd/backend.py

-2
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,6 @@ def array(self, arr, dtype=None, **kwargs):
338338
return arr.clone().to(device="cuda", dtype=dtype, **kwargs)
339339
return torch.tensor(arr, device="cuda", dtype=dtype, **kwargs)
340340

341-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
342-
# The same warning applies here.
343341
def numpy(self, arr):
344342
"""convert the array to numpy array"""
345343
if torch.is_tensor(arr):

fdtd/grid.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -494,12 +494,21 @@ def save_data(self):
494494
495495
Parameters: None
496496
"""
497+
def _numpyfy(item):
498+
if isinstance(item, list):
499+
return [_numpyfy(el) for el in item]
500+
elif bd.is_array(item):
501+
return bd.numpy(item)
502+
else:
503+
return item
504+
497505
if self.folder is None:
498506
raise Exception(
499507
"Save location not initialized. Please read about 'fdtd.Grid.saveSimulation()' or try running 'grid.saveSimulation()'."
500508
)
501509
dic = {}
502510
for detector in self.detectors:
503-
dic[detector.name + " (E)"] = [x for x in detector.detector_values()["E"]]
504-
dic[detector.name + " (H)"] = [x for x in detector.detector_values()["H"]]
511+
values = detector.detector_values()
512+
dic[detector.name + " (E)"] = _numpyfy(values['E'])
513+
dic[detector.name + " (H)"] = _numpyfy(values['H'])
505514
savez(path.join(self.folder, "detector_readings"), **dic)

fdtd/visualization.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -360,14 +360,14 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
360360
a[i].append(max(temp) - min(temp))
361361

362362
peakVal, minVal = max(map(max, a)), min(map(min, a))
363-
print(
364-
"Peak at:",
365-
[
366-
[[i, j] for j, y in enumerate(x) if y == peakVal]
367-
for i, x in enumerate(a)
368-
if peakVal in x
369-
],
370-
)
363+
#print(
364+
# "Peak at:",
365+
# [
366+
# [[i, j] for j, y in enumerate(x) if y == peakVal]
367+
# for i, x in enumerate(a)
368+
# if peakVal in x
369+
# ],
370+
#)
371371
a = 10 * log10([[y / minVal for y in x] for x in a])
372372

373373
plt.title("dB map of Electrical waves in detector region")

0 commit comments

Comments
 (0)