Skip to content

[Bugfix] Fix moe weight losing all extra attrs after process_weights_after_loading. #16854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,9 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
Expand All @@ -127,10 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)

layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2

if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
Expand Down