Skip to content

Commit 7d2d288

Browse files
author
wabywang(王本友)
committed
GPU_Variable_shape_debug
1 parent 5f9bba3 commit 7d2d288

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

models/Transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def forward(self, inp):
425425
src_seq,src_pos = inp
426426
# enc_output, *_ = self.encoder(src_seq, src_pos) #64x200x512
427427
enc_output = self.encoder(src_seq, src_pos) #64x200x512
428-
return self.hidden2label(enc_output.view((enc_output.shape[0],-1)))
428+
return self.hidden2label(enc_output.view((self.batch_size,-1)))
429429

430430

431431

0 commit comments

Comments
 (0)