Skip to content

Commit e072fd0

Browse files
author
samdow
committed
embedding decomp
1 parent 915aecb commit e072fd0

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

functorch/_src/eager_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,5 +1340,6 @@ def _register_python_decomposition_vmap(decomp):
13401340
_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
13411341
_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
13421342
_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default, use_python=True)
1343+
_register_jit_decomposition(torch.ops.aten.embedding_dense_backward.default)
13431344
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
13441345
_register_python_decomposition_vmap(torch.ops.aten.addr.default)

functorch/csrc/DynamicLayer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
503503
JVP_DECOMP(log_sigmoid_forward);
504504
JVP_DECOMP(native_layer_norm_backward);
505505
JVP_DECOMP(native_batch_norm_backward);
506+
JVP_DECOMP(embedding_dense_backward);
506507
}
507508

508509

0 commit comments

Comments
 (0)