Skip to content

Commit f8d6936

Browse files
authored
Merge pull request #26 from boazcogan/fix/issue_18
Fix to issue #18 and correcting rfft call.
2 parents ef60ba2 + 9208681 commit f8d6936

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

book/training/building_blocks.ipynb

+5-6
Original file line numberDiff line numberDiff line change
@@ -1606,14 +1606,13 @@
16061606
"fig, ax = plt.subplots(8, 8, figsize=(12, 10))\n",
16071607
"ax = ax.flatten()\n",
16081608
"\n",
1609-
"fourier_basis = torch.rfft(\n",
1610-
" torch.eye(filter_length), \n",
1611-
" 1, onesided=True\n",
1609+
"fourier_basis = torch.fft.rfft(\n",
1610+
" torch.eye(filter_length)\n",
16121611
")\n",
16131612
"cutoff = 1 + filter_length // 2\n",
16141613
"fourier_basis = torch.cat([\n",
1615-
" fourier_basis[:, :cutoff, 0],\n",
1616-
" fourier_basis[:, :cutoff, 1]\n",
1614+
" torch.real(fourier_basis[:, :cutoff]),\n",
1615+
" torch.imag(fourier_basis[:, :cutoff])\n",
16171616
"], dim=1)\n",
16181617
"fourier_basis = fourier_basis.float()\n",
16191618
"\n",
@@ -1729,7 +1728,7 @@
17291728
"name": "python",
17301729
"nbconvert_exporter": "python",
17311730
"pygments_lexer": "ipython3",
1732-
"version": "3.8.5"
1731+
"version": "3.8.8"
17331732
}
17341733
},
17351734
"nbformat": 4,

book/training/gradient_descent.ipynb

+13-11
Original file line numberDiff line numberDiff line change
@@ -365,13 +365,14 @@
365365
" possible_b.shape[0]\n",
366366
"))\n",
367367
"\n",
368-
"for i, m_hat in enumerate(possible_m):\n",
369-
" for j, b_hat in enumerate(possible_b):\n",
370-
" line.layer.weight[0, 0] = m_hat\n",
371-
" line.layer.bias[0] = b_hat\n",
372-
" y_hat = line(to_tensor(x))\n",
373-
" _loss = (y_hat - to_tensor(y)).abs().mean()\n",
374-
" loss[i, j] = _loss\n",
368+
"with torch.no_grad():\n",
369+
" for i, m_hat in enumerate(possible_m):\n",
370+
" for j, b_hat in enumerate(possible_b):\n",
371+
" line.layer.weight[0, 0] = m_hat\n",
372+
" line.layer.bias[0] = b_hat\n",
373+
" y_hat = line(to_tensor(x))\n",
374+
" _loss = (y_hat - to_tensor(y)).abs().mean()\n",
375+
" loss[i, j] = _loss\n",
375376
"\n",
376377
"plt.title('2D Visualization of Loss Landscape')\n",
377378
"plt.pcolormesh(possible_m, possible_b, loss, shading='auto')\n",
@@ -442,10 +443,11 @@
442443
}
443444
],
444445
"source": [
445-
"line.layer.weight[0, 0] = m_hat\n",
446-
"line.layer.bias[0] = b_hat\n",
446+
"with torch.no_grad():\n",
447+
" line.layer.weight[0, 0] = m_hat\n",
448+
" line.layer.bias[0] = b_hat\n",
447449
"\n",
448-
"y_hat = line(to_tensor(x))\n",
450+
" y_hat = line(to_tensor(x))\n",
449451
"\n",
450452
"plt.title(\"Training data + network predictions\")\n",
451453
"plt.scatter(x, y, label='Training data')\n",
@@ -1014,7 +1016,7 @@
10141016
"name": "python",
10151017
"nbconvert_exporter": "python",
10161018
"pygments_lexer": "ipython3",
1017-
"version": "3.8.5"
1019+
"version": "3.8.8"
10181020
}
10191021
},
10201022
"nbformat": 4,

0 commit comments

Comments
 (0)