Commit d3608d6

Remove dummy forward path (#3669)

1 parent: dbd9a83

8 files changed: +12, -107 lines


docs/source/torch/attention.md  (+0, -1)

@@ -65,7 +65,6 @@ It contains the following predefined fields:
 | request_ids | List[int] | The request ID of each sequence in the batch. |
 | prompt_lens | List[int] | The prompt length of each sequence in the batch. |
 | kv_cache_params | KVCacheParams | The parameters for the KV cache. |
-| is_dummy_attention | bool | Indicates whether this is a simulation-only attention operation used for KV cache memory estimation. Defaults to False. |

 During `AttentionMetadata.__init__`, you can initialize additional fields for the new attention metadata.
 For example, the Flashinfer metadata initializes `decode_wrapper` here.
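
As background for the documented extension point, here is a minimal sketch of a backend-specific metadata class that adds an extra field and fills it in during post-initialization. The class and field names (MyAttentionMetadata, cu_seqlens) are hypothetical illustrations, not part of the TensorRT-LLM API.

from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass(kw_only=True)
class MyAttentionMetadata:  # hypothetical stand-in for an AttentionMetadata subclass
    max_num_requests: int
    seq_lens: Optional[torch.Tensor] = None
    # Backend-specific field, not passed by the caller; built in __post_init__.
    cu_seqlens: torch.Tensor = field(init=False)

    def __post_init__(self) -> None:
        # Precompute cumulative sequence lengths, similar in spirit to how a
        # backend might build plan wrappers or index tensors during __init__.
        lens = self.seq_lens if self.seq_lens is not None else torch.zeros(
            self.max_num_requests, dtype=torch.int32)
        self.cu_seqlens = torch.nn.functional.pad(
            torch.cumsum(lens, dim=0, dtype=torch.int32), (1, 0))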

tensorrt_llm/_torch/attention_backend/flashinfer.py  (+1, -9)

@@ -12,7 +12,7 @@

 from ..utils import get_global_attrs, get_model_extra_attrs
 from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
-                        PredefinedAttentionMask, dummy_forward)
+                        PredefinedAttentionMask)

 try:
     check_cuda_arch()
@@ -465,14 +465,6 @@ def forward_pattern(
     else:
         metadata = get_global_attrs().attention_metadata()

-    # This is only for memory estimation for now.
-    # NOTE: this method is not accurate while it works for most scenario.
-    if metadata is None or metadata.kv_cache_manager is None:
-        q = q.view(-1, num_heads, head_dim)
-        k = k.view(-1, num_kv_heads, head_dim)
-        v = v.view(-1, num_kv_heads, head_dim)
-        return dummy_forward(q, k, v)
-
     assert isinstance(
         metadata,
         FlashInferAttentionMetadata,

tensorrt_llm/_torch/attention_backend/interface.py  (+0, -35)

@@ -7,7 +7,6 @@
                     Union)

 import torch
-from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from typing_extensions import Self

 from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
@@ -125,7 +124,6 @@ class AttentionMetadata:
     _num_generations: int = field(init=False, default=0, repr=False)
     _num_ctx_tokens: int = field(init=False, default=0, repr=False)
     _num_tokens: int = field(init=False, default=0, repr=False)
-    is_dummy_attention: bool = False

     def __post_init__(self) -> None:
         if self.is_cross:
@@ -548,36 +546,3 @@ class MLAParams:
     qk_nope_head_dim: int = 0
     v_head_dim: int = 0
     predicted_tokens_per_seq: int = 1
-
-
-@torch.library.custom_op("trtllm::attn_dummy_fwd", mutates_args=())
-def dummy_forward(q: torch.Tensor, k: torch.Tensor,
-                  v: torch.Tensor) -> torch.Tensor:
-    """
-    Dummy attention forward function to estimate memory usage.
-    Args:
-        q (torch.Tensor): Query tensor with shape (num_q_tokens, num_heads, head_dim).
-        k (torch.Tensor): Key tensor with shape (num_new_kv_tokens, num_kv_heads, head_dim)
-        v (torch.Tensor): Value tensor with shape (num_new_kv_tokens, num_kv_heads, head_dim)
-    Returns:
-        torch.Tensor with shape (num_q_tokens, num_heads * head_dim)
-    """
-    head_dim = q.shape[2]
-    assert q.dim() == 3
-    assert k.dim() == 3 and k.size(2) == head_dim
-    assert v.dim() == 3 and v.size(2) == head_dim
-    # This is only for memory estimation for now.
-    # NOTE: this method is not accurate while it works for most scenario.
-    o = _flash_attention_forward(q.unsqueeze(0),
-                                 k.unsqueeze(0),
-                                 v.unsqueeze(0),
-                                 attention_mask=None,
-                                 query_length=q.size(0),
-                                 is_causal=True)
-    return o.reshape(o.size(1), -1)
-
-
-@dummy_forward.register_fake
-def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-    num_q_tokens = q.size(0)
-    return torch.empty_like(q).reshape(num_q_tokens, -1)
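
For readers unfamiliar with the pattern being deleted: the removed function was registered as a PyTorch custom op with a "fake" implementation so that shape/dtype propagation works under tracing (e.g. with fake tensors). Below is a minimal, self-contained sketch of that registration pattern; the op name and the math are made up for illustration, and torch.library.custom_op is assumed to be available (PyTorch 2.4+).

import torch


@torch.library.custom_op("demo::scaled_matmul", mutates_args=())
def scaled_matmul(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # Real ("eager") implementation executed at runtime.
    return (a @ b) * scale


@scaled_matmul.register_fake
def _(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing; no real compute.
    return a.new_empty(a.shape[0], b.shape[1])


if __name__ == "__main__":
    x = torch.randn(4, 8)
    y = torch.randn(8, 16)
    out = scaled_matmul(x, y, 0.5)  # dispatches to the eager implementation
    print(out.shape)                # torch.Size([4, 16])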

tensorrt_llm/_torch/attention_backend/star_flashinfer.py  (+1, -7)

@@ -10,8 +10,7 @@

 from ..distributed import allgather
 from .flashinfer import FlashInferAttentionMetadata, PlanParams
-from .interface import (AttentionBackend, AttentionMask,
-                        PredefinedAttentionMask, dummy_forward)
+from .interface import AttentionBackend, AttentionMask, PredefinedAttentionMask


 # Please sync with flashinfer's DISPATCH_GQA_GROUP_SIZE in include/flashinfer/utils.cuh
@@ -326,11 +325,6 @@ def forward(self,
         k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)

-        # This is only for memory estimation for now.
-        # NOTE: this method is not accurate while it works for most scenario.
-        if metadata is None or metadata.kv_cache_manager is None:
-            return dummy_forward(q, k, v)
-
         num_contexts = metadata.num_contexts
         num_queries = metadata.num_queries
         num_generations = metadata.num_generations

tensorrt_llm/_torch/attention_backend/trtllm.py  (+2, -27)

@@ -10,7 +10,7 @@
 from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
                         AttentionMetadata, KVCacheParams, MLAParams,
                         PositionalEmbeddingParams, PredefinedAttentionMask,
-                        RopeParams, dummy_forward)
+                        RopeParams)


 @dataclass(kw_only=True, init=False)
@@ -489,7 +489,7 @@ def __post_init__(self) -> None:

     def prepare(self) -> None:

-        if not self.is_dummy_attention and self.kv_cache_manager is None:
+        if self.kv_cache_manager is None:
             # Convert the attention metadata to a TRT-LLM no cache attention metadata.
             assert self.kv_cache_manager is None, "no cache attention should not have KV cache manager"
             assert self._max_seq_len_storage is not None, "max_seq_len should be set for no cache attention"
@@ -641,31 +641,6 @@ def forward(
         mrope_config: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
-        # This is only for memory estimation for now.
-        # NOTE: this method is not accurate while it works for most scenario.
-        if metadata.is_dummy_attention:
-            q_size = self.num_heads * self.head_dim
-            k_size = self.num_kv_heads * self.head_dim
-            v_size = self.num_kv_heads * self.v_head_dim
-            q, k, v = q.split([q_size, k_size, v_size], dim=-1)
-            q = q.view(-1, self.num_heads, self.head_dim)
-            k = k.view(-1, self.num_kv_heads, self.head_dim)
-            v = v.view(-1, self.num_kv_heads, self.v_head_dim)
-            if self.head_dim != self.v_head_dim:
-                # the dummy forward doesn't support head_dim != v_head_dim case
-                # so we use a tensor with supported shape to replace the v
-                # the memory estimation is not accurate in this case
-                v = torch.randn(q.shape[0],
-                                self.num_kv_heads,
-                                self.head_dim,
-                                dtype=q.dtype,
-                                device=q.device)
-            output = dummy_forward(q, k, v)
-            if self.head_dim != self.v_head_dim:
-                output = output[..., :self.num_kv_heads *
-                                self.v_head_dim].contiguous()
-            return output
-
         assert isinstance(
             metadata,
             TrtllmAttentionMetadata,
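
The removed branch above split a fused QKV tensor by head counts before reshaping each piece to (tokens, heads, head_dim). As a reference, here is a minimal sketch of that split-and-view pattern on a fused projection output; the tensor names and sizes are illustrative only and do not come from the TRT-LLM backend.

import torch

num_tokens, num_heads, num_kv_heads, head_dim = 6, 8, 2, 64

# Fused QKV projection output: one row per token, Q/K/V packed along the last dim.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim)

q_size = num_heads * head_dim
k_size = num_kv_heads * head_dim
v_size = num_kv_heads * head_dim

# Split along the feature dimension, then expose the per-head layout.
q, k, v = qkv.split([q_size, k_size, v_size], dim=-1)
q = q.view(-1, num_heads, head_dim)     # (num_tokens, num_heads, head_dim)
k = k.view(-1, num_kv_heads, head_dim)  # (num_tokens, num_kv_heads, head_dim)
v = v.view(-1, num_kv_heads, head_dim)  # (num_tokens, num_kv_heads, head_dim)
print(q.shape, k.shape, v.shape)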

tensorrt_llm/_torch/attention_backend/vanilla.py  (+2, -12)

@@ -11,7 +11,7 @@
     AttentionMaskConverter = None

 from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
-                        PredefinedAttentionMask, dummy_forward)
+                        PredefinedAttentionMask)


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -230,12 +230,7 @@ def forward(self,
                 *,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                 **kwargs) -> torch.Tensor:
-
-        # This is only for memory estimation for now.
-        # NOTE: this method is not accurate while it works for most scenario.
-        if metadata.is_dummy_attention:
-            return dummy_forward(q, k, v)
-        elif metadata.kv_cache_manager is None:
+        if metadata.kv_cache_manager is None:
             # NOTE: WAR for no kv cache attn e.g. BERT,
             # try to separate the kv cache estimation path from no kv cache attn.
             num_heads = self.num_heads
@@ -249,11 +244,6 @@ def forward(self,
                 metadata=metadata,
                 attention_mask=attention_mask)

-        # This is only for memory estimation for now.
-        # NOTE: this method is not accurate while it works for most scenario.
-        if metadata is None or metadata.kv_cache_manager is None:
-            return dummy_forward(q, k, v)
-
         past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
         cache_indices = [
             block_ids[0] for block_ids in metadata.block_ids_per_seq
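
For context, the surviving no-KV-cache branch (used for encoder-style models such as BERT) presumably computes ordinary attention over just the current tokens. A rough, hedged sketch of that kind of computation using PyTorch's built-in scaled_dot_product_attention follows; this is an illustration under assumed shapes, not the vanilla backend's actual code path.

import torch
import torch.nn.functional as F

num_tokens, num_heads, head_dim = 16, 8, 64

q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
v = torch.randn(num_tokens, num_heads, head_dim)

# SDPA expects (batch, heads, seq, head_dim); treat all tokens as one sequence.
out = F.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),
    k.transpose(0, 1).unsqueeze(0),
    v.transpose(0, 1).unsqueeze(0),
    is_causal=True)

# Back to (num_tokens, num_heads * head_dim), the usual attention output layout.
out = out.squeeze(0).transpose(0, 1).reshape(num_tokens, num_heads * head_dim)
print(out.shape)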

tensorrt_llm/_torch/pyexecutor/model_engine.py  (+5, -13)

@@ -576,21 +576,15 @@ def _create_extra_inputs(bs, num_tokens_per_request):
                 extra_model_inputs=_create_extra_inputs(bs, 1))
             torch.cuda.synchronize()

-    def _set_up_attn_metadata(self,
-                              kv_cache_manager: KVCacheManager,
-                              is_dummy_forward: bool = False):
-        # is_dummy_forward is used to indicate whether the forward is
-        # a dummy forward for memory estimation OR
-        # a real forward w.o. kv cache
+    def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
         if kv_cache_manager is None:
             return self.attn_backend.Metadata(
                 max_num_requests=self.batch_size,
                 max_num_tokens=self.max_num_tokens,
                 kv_cache_manager=None,
                 mapping=self.mapping,
                 runtime_features=self.attn_runtime_features,
-                enable_flash_mla=self.model.model_config.enable_flash_mla,
-                is_dummy_attention=is_dummy_forward)
+                enable_flash_mla=self.model.model_config.enable_flash_mla)

         if self.attn_metadata is not None:
             # This assertion can be relaxed if needed: just create a new metadata
@@ -1282,7 +1276,7 @@ def _prepare_tp_inputs_no_cache(
         all_rank_num_tokens = self.dist.allgather(attn_metadata.num_tokens)
         attn_metadata.all_rank_num_tokens = all_rank_num_tokens
         # this is for no cache attention, not for dummy attention
-        if not attn_metadata.is_dummy_attention and attn_metadata.kv_cache_manager is None:
+        if attn_metadata.kv_cache_manager is None:
             assert isinstance(
                 attn_metadata,
                 (VanillaAttentionMetadata, TrtllmAttentionMetadata)
@@ -1596,14 +1590,12 @@ def forward(self,
                 scheduled_requests: ScheduledRequests,
                 resource_manager: ResourceManager,
                 new_tensors_device: Optional[Dict[str, torch.Tensor]] = None,
-                extra_model_inputs: Optional[Dict[str, Any]] = None,
-                is_dummy_forward: bool = False):
+                extra_model_inputs: Optional[Dict[str, Any]] = None):

         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)

-        attn_metadata = self._set_up_attn_metadata(kv_cache_manager,
-                                                   is_dummy_forward)
+        attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
         if self.spec_config is not None:
             spec_resource_manager = resource_manager.get_resource_manager(
                 'spec_resource_manager')

tests/unittest/_torch/test_attention_no_cache.py  (+1, -3)

@@ -246,9 +246,7 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers,
         max_num_tokens=8192,
         kv_cache_manager=None,
         mapping=None,
-        runtime_features=None,
-        is_dummy_attention=False,
-    )
+        runtime_features=None)

     # NOTE: set up metadata
     attn_metadata.seq_lens = torch.tensor(sequence_lengths, dtype=torch.int)
