PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.28.1
Libc version: glibc-2.31
Python version: 3.11.11 | packaged by conda-forge | (main, Mar 3 2025, 20:43:55) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
Nvidia driver version: 550.107.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 256
On-line CPU(s) list: 0-255
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7763 64-Core Processor
Stepping: 1
Frequency boost: enabled
CPU MHz: 1500.000
CPU max MHz: 3529.0520
CPU min MHz: 1500.0000
BogoMIPS: 4900.38
Virtualization: AMD-V
L1d cache: 4 MiB
L1i cache: 4 MiB
L2 cache: 64 MiB
L3 cache: 512 MiB
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.3.0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchvision==0.21.0
[pip3] transformers==4.51.2
[pip3] triton==3.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
[conda] nvidia-ml-py 12.570.86 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] pyzmq 26.3.0 pypi_0 pypi
[conda] torch 2.6.0 pypi_0 pypi
[conda] torchaudio 2.6.0 pypi_0 pypi
[conda] torchvision 0.21.0 pypi_0 pypi
[conda] transformers 4.51.2 pypi_0 pypi
[conda] triton 3.2.0 pypi_0 pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.8.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X SYS 0-63,128-191 0 N/A
GPU1 SYS X 64-127,192-255 1 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
LD_LIBRARY_PATH=/home/whzhang/workspace/LLM/ollama/lib/ollama:/usr/local/cuda/lib64:/home/whzhang/workspace/LLM/ollama/lib/ollama:/usr/local/cuda/lib64:/home/whzhang/workspace/LLM/ollama/lib/ollama:/usr/local/cuda/lib64:
CUDA_VISIBLE_DEVICES=0,1
CUDA_VISIBLE_DEVICES=0,1
TORCH_USE_CUDA_DSA=1
CUDA_LAUNCH_BLOCKING=1
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256
VLLM_WORKER_MULTIPROC_METHOD=spawn
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
When benchmark_serving.py is run for the first time with a given `--num-prompts` value, the TTFT is very long. A second run with the same setting is much better.
First run:
============ Serving Benchmark Result ============
Successful requests: 10
Benchmark duration (s): 8.85
Total input tokens: 3354
Total generated tokens: 1200
Request throughput (req/s): 1.13
Output token throughput (tok/s): 135.58
Total Token throughput (tok/s): 514.51
---------------Time to First Token----------------
Mean TTFT (ms): 361.76
Median TTFT (ms): 421.65
P99 TTFT (ms): 709.10
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 29.22
Median TPOT (ms): 29.21
P99 TPOT (ms): 32.98
---------------Inter-token Latency----------------
Mean ITL (ms): 32.65
Median ITL (ms): 27.44
P99 ITL (ms): 231.45
==================================================
Second run:
============ Serving Benchmark Result ============
Successful requests: 10
Benchmark duration (s): 8.79
Total input tokens: 3354
Total generated tokens: 1200
Request throughput (req/s): 1.14
Output token throughput (tok/s): 136.47
Total Token throughput (tok/s): 517.91
---------------Time to First Token----------------
Mean TTFT (ms): 45.84
Median TTFT (ms): 45.04
P99 TTFT (ms): 62.59
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 27.12
Median TPOT (ms): 27.19
P99 TPOT (ms): 27.80
---------------Inter-token Latency----------------
Mean ITL (ms): 27.12
Median ITL (ms): 26.95
P99 ITL (ms): 31.99
==================================================
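One way to keep a one-time warm-up cost out of the measured numbers is to send a throwaway burst of requests with similar prompt lengths before starting the benchmark. The sketch below assumes a vLLM OpenAI-compatible server at `http://localhost:8000/v1/completions`; the filler prompt only approximates the target token count, and whether this removes the first-run TTFT spike depends on its actual cause, which this report does not establish.

```python
import json
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List

# Assumption: vLLM's OpenAI-compatible server listens on the default port.
COMPLETIONS_URL = "http://localhost:8000/v1/completions"


def build_warmup_payloads(model: str, prompt_lens: List[int],
                          max_tokens: int = 16) -> List[dict]:
    """Build one completion request per target prompt length."""
    return [
        {
            "model": model,
            "prompt": "hello " * n,  # crude filler, only roughly n tokens
            "max_tokens": max_tokens,
            "temperature": 0.0,
        }
        for n in prompt_lens
    ]


def post_json(payload: dict, url: str = COMPLETIONS_URL,
              timeout: float = 120.0) -> dict:
    """Send one blocking POST request and decode the JSON response body."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def warm_up(payloads: List[dict],
            send: Callable[[dict], dict] = post_json) -> int:
    """Send all payloads concurrently, like a benchmark burst; return the count."""
    with ThreadPoolExecutor(max_workers=max(1, len(payloads))) as pool:
        # list() forces completion and re-raises any request error here.
        results = list(pool.map(send, payloads))
    return len(results)
```

For example, `warm_up(build_warmup_payloads("my-model", [335] * 10))` could be called before launching `benchmark_serving.py`; `"my-model"` is a placeholder, and 335 is simply 3354 input tokens divided by the 10 requests reported above.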
I also observed another phenomenon: the measured TTFT is quite different from the measured per-iteration execution time. Is this because the time spent processing inputs and outputs is not taken into account? If I want to implement custom scheduling on the server side based on TTFT/TPOT, how can I get an accurate TTFT on the server side? Or is only the actual computation time meaningful on the server side?
I modified `vllm/v1/engine/core.py` (`EngineCore.step()`) to measure the computation time of one iteration:
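Where TTFT is measured decides what it covers. `benchmark_serving.py` times it on the client, from sending the HTTP request to receiving the first streamed chunk, so it includes HTTP handling, tokenization, queueing and detokenization on top of the prefill iteration. A server-side estimate can be kept by stamping each request on arrival and again when its first output token is emitted. `TTFTTracker` below is a hypothetical helper, not a vLLM API; where `on_arrival` and `on_tokens` would be called (for example, the API server's request handler versus the engine's output path) is an assumption, and calling `on_arrival` later excludes more of the pre-engine overhead.

```python
import time
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional


@dataclass
class _ReqTimes:
    arrival: float
    first_token: Optional[float] = None


@dataclass
class TTFTTracker:
    """Hypothetical server-side TTFT bookkeeping (not part of vLLM)."""
    clock: Callable[[], float] = time.monotonic
    _reqs: Dict[str, _ReqTimes] = field(default_factory=dict, init=False)

    def on_arrival(self, req_id: str) -> None:
        # Call as early as the TTFT definition requires (e.g. HTTP receipt).
        self._reqs[req_id] = _ReqTimes(arrival=self.clock())

    def on_tokens(self, req_id: str) -> Optional[float]:
        """Call whenever tokens are emitted; returns TTFT (s) only on the first call."""
        r = self._reqs[req_id]
        if r.first_token is not None:
            return None
        r.first_token = self.clock()
        return r.first_token - r.arrival

    def ttft(self, req_id: str) -> Optional[float]:
        """TTFT in seconds, or None if unknown or no token emitted yet."""
        r = self._reqs.get(req_id)
        if r is None or r.first_token is None:
            return None
        return r.first_token - r.arrival

    def on_finish(self, req_id: str) -> Optional[float]:
        """Drop the request's bookkeeping and return its TTFT, if any."""
        value = self.ttft(req_id)
        self._reqs.pop(req_id, None)
        return value
```

Injecting `clock` keeps the tracker testable and lets a scheduler use whichever monotonic clock it already relies on.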
def step(self) -> EngineCoreOutputs:
    """Schedule, execute, and make output."""
    with open(self.log_path, 'a') as log_file:
        # last iteration time
        if self.last_iter_begin_time and self.total_num_scheduled_tokens and not self.scheduler.get_is_warmup():
            message = f"iter computing time: {float(time.perf_counter_ns() - self.last_iter_begin_time) / 1e6} ms\n"
            print(message)
            log_file.write(message + '\n')
        self.last_iter_begin_time = time.perf_counter_ns()

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        self.total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.total_num_scheduled_tokens and not self.scheduler.get_is_warmup():
            # new iteration info
            message = f"\033[93mnew iteration, token budget: {self.scheduler.get_token_budget()}, total scheduled tokens: {self.total_num_scheduled_tokens}\033[0m"
            print(message)
            message = f"new iteration, token budget: {self.scheduler.get_token_budget()}, total scheduled tokens: {self.total_num_scheduled_tokens}"
            log_file.write(message + '\n')

            # new request info
            for req in scheduler_output.scheduled_new_reqs:
                new_req_idx = self.request_idx_dict[req.req_id]
                new_req_lora_id = self.request_lora_id_dict[req.req_id]
                new_req_lora_rank = self.lora_model_rank_dict[new_req_lora_id]
                new_req_scheduled_tokens = scheduler_output.num_scheduled_tokens[req.req_id]
                message = f"new request id: {new_req_idx}, scheduled tokens: {new_req_scheduled_tokens}, lora id: {new_req_lora_id}, lora rank: {new_req_lora_rank}"
                print(message)
                log_file.write(message + '\n')

            # cached request info
            for req in scheduler_output.scheduled_cached_reqs:
                cached_req_idx = self.request_idx_dict[req.req_id]
                cached_req_lora_id = self.request_lora_id_dict[req.req_id]
                cached_req_lora_rank = self.lora_model_rank_dict[cached_req_lora_id]
                cached_req_scheduled_tokens = scheduler_output.num_scheduled_tokens[req.req_id]
                message = f"cached request id: {cached_req_idx}, scheduled tokens: {cached_req_scheduled_tokens}, lora id: {cached_req_lora_id}, lora rank: {cached_req_lora_rank}"
                print(message)
                log_file.write(message + '\n')

        return engine_core_outputs
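Measuring from one `step()` entry to the next, as in the modified code, also counts whatever the engine loop does between calls, so it is an upper bound on the computation time of a step rather than a pure measure of it. A small per-phase timer makes the split explicit. `PhaseTimer` is a hypothetical helper; the commented lines show, as an untested assumption, where it could wrap the three calls already present in the modified `step()`. Note also that `CUDA_LAUNCH_BLOCKING=1` is set in the environment above, which makes kernel launches synchronous, so wall-clock timings are likely to differ from a run without it.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


class PhaseTimer:
    """Accumulate wall-clock time per named phase, in milliseconds."""

    def __init__(self) -> None:
        self.ms: Dict[str, float] = {}

    @contextmanager
    def phase(self, name: str) -> Iterator[None]:
        start = time.perf_counter_ns()
        try:
            yield
        finally:
            # Add this interval to the phase's running total.
            elapsed = (time.perf_counter_ns() - start) / 1e6
            self.ms[name] = self.ms.get(name, 0.0) + elapsed

    def report(self) -> str:
        return ", ".join(f"{k}: {v:.3f} ms" for k, v in self.ms.items())


# Hypothetical instrumentation inside the modified step() (not executed here):
#
#     timer = PhaseTimer()
#     with timer.phase("schedule"):
#         scheduler_output = self.scheduler.schedule()
#     with timer.phase("execute"):
#         output = self.model_executor.execute_model(scheduler_output)
#     with timer.phase("update"):
#         engine_core_outputs = self.scheduler.update_from_output(
#             scheduler_output, output)
#     log_file.write(timer.report() + "\n")
```

Timing each call separately shows how much of an iteration is model execution versus scheduling and output bookkeeping, without including time spent outside `step()`.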
The maximum iteration time I measured corresponds to the TTFT of request 7. It is only about 33 ms, which is smaller than the measured TTFT (the mean TTFT is about 45 ms).
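To check whether the gap comes from work outside the engine step, the client-observed TTFT can be split into consecutive segments, given timestamps recorded at each hand-off. The stage names and segment labels below are hypothetical, and the timestamps must come from one clock shared by client and server (for example `time.time()` with both on the same host). Under this split the roughly 33 ms iteration would fall in the `prefill_step` segment, and against a mean TTFT of about 45 ms roughly 12 ms would be left for the other segments; that is arithmetic on the reported numbers, not a measured breakdown.

```python
from typing import Dict

# Hypothetical hand-off points, in the order a request passes through them.
STAGES = ["client_send", "server_recv", "engine_scheduled",
          "first_token_ready", "client_first_chunk"]
# Label for the interval between each pair of consecutive stages.
SEGMENTS = ["network_and_http", "preprocess_and_queue",
            "prefill_step", "postprocess_and_stream"]


def ttft_breakdown(ts: Dict[str, float]) -> Dict[str, float]:
    """Return each segment's duration plus end-to-end TTFT, in the unit of ts."""
    missing = [s for s in STAGES if s not in ts]
    if missing:
        raise KeyError(f"missing timestamps: {missing}")
    out = {seg: ts[b] - ts[a]
           for seg, a, b in zip(SEGMENTS, STAGES, STAGES[1:])}
    out["ttft_total"] = ts[STAGES[-1]] - ts[STAGES[0]]
    return out
```

Because the segments are consecutive, they always sum to `ttft_total`, so any time not explained by the engine step shows up in one of the other three buckets.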
sjtu-zwh changed the title from "[Bug]: abnormal TTFT (too long) in the first serving" to "[Bug]: vllm 0.8.3 abnormal TTFT (too long) in the first serving" on Apr 19, 2025.