Skip to content

Commit 4e73dc3

Browse files
author
Marko Pranjic
committed
Add comment explaining the encoding of the position information.
1 parent 561d409 commit 4e73dc3

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

model_pytorch.py

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def __init__(self, cfg, vocab=40990, n_ctx=512):
165165
def forward(self, x):
166166
x = x.view(-1, x.size(-2), x.size(-1))
167167
e = self.embed(x)
168+
# Add the position information to the input embeddings
168169
h = e.sum(dim=2)
169170
for block in self.h:
170171
h = block(h)

train.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def transform_roc(X1, X2, X3):
3232
xmb[i, 1, :l13, 0] = x13
3333
mmb[i, 0, :l12] = 1
3434
mmb[i, 1, :l13] = 1
35+
# Position information that is added to the input embeddings in the TransformerModel
3536
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
3637
return xmb, mmb
3738

0 commit comments

Comments (0)