
Commit f521611

charlifu authored and lk-chen committed
[Bugfix] Fix moe weight losing all extra attrs after process_weights_after_loading. (vllm-project#16854)
Signed-off-by: charlifu <charlifu@amd.com>
1 parent bfffc90 commit f521611

2 files changed: +6, -11 lines

vllm/model_executor/layers/fused_moe/layer.py (+5, -10)

@@ -113,12 +113,9 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
-        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w13_weight.data),
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w2_weight.data),
-                                             requires_grad=False)
+        # Padding the weight for better performance on ROCm
+        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             is_rocm_aiter_moe_enabled, shuffle_weights)
@@ -127,10 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         shuffled_w13, shuffled_w2 = shuffle_weights(
             layer.w13_weight.data, layer.w2_weight.data)
 
-        layer.w13_weight = torch.nn.Parameter(shuffled_w13,
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(shuffled_w2,
-                                             requires_grad=False)
+        layer.w13_weight.data = shuffled_w13
+        layer.w2_weight.data = shuffled_w2
 
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
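
Why the .data assignment matters: vLLM attaches extra attributes (such as a weight_loader hook) to the MoE weight parameters, and re-wrapping the tensor in a fresh torch.nn.Parameter silently drops them, which is the bug this commit fixes. A minimal sketch of the behavior, assuming a vLLM-style weight_loader attribute (illustrative only, not vLLM code):

import torch
import torch.nn.functional as F

layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
# Extra attribute attached to the parameter, as vLLM does for weight loading.
layer.w13_weight.weight_loader = lambda param, w: param.data.copy_(w)

# Old code path: wrapping the padded tensor in a new Parameter
# discards every attribute set on the old one.
layer.w13_weight = torch.nn.Parameter(F.pad(layer.w13_weight.data, (0, 2)),
                                      requires_grad=False)
print(hasattr(layer.w13_weight, "weight_loader"))  # False

# Fixed path: mutating .data keeps the same Parameter object, so the
# extra attributes survive.
layer.w13_weight = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
layer.w13_weight.weight_loader = lambda param, w: param.data.copy_(w)
layer.w13_weight.data = F.pad(layer.w13_weight.data, (0, 2))
print(hasattr(layer.w13_weight, "weight_loader"))  # True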

vllm/model_executor/layers/quantization/utils/w8a8_utils.py (+1, -1)

@@ -156,7 +156,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
     from vllm.platforms.rocm import on_mi250_mi300
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
     ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
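
The second change inverts the platform guard: ops.wvSplitKQ is the skinny-GEMM kernel taken on MI250/MI300, so the condition must require on_mi250_mi300() rather than exclude it. A sketch of the corrected gating logic, with an illustrative helper name that is not vLLM API:

def use_skinny_gemm(skinny_env: bool, on_mi250_mi300: bool,
                    m: int, k: int) -> bool:
    # Mirrors the fixed guard: take the wvSplitKQ path only on
    # MI250/MI300, for single-row inputs with a 16-aligned inner dim.
    return skinny_env and on_mi250_mi300 and m == 1 and k % 16 == 0

# Before the fix the guard read "... and not on_mi250_mi300() ...",
# so the MI250/MI300-specific kernel was selected on exactly the
# wrong GPUs.
assert use_skinny_gemm(True, True, 1, 64)        # MI300: skinny path
assert not use_skinny_gemm(True, False, 1, 64)   # other GPUs: fallback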
