@@ -113,12 +113,9 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
-        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w13_weight.data),
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
-            layer.w2_weight.data),
-                                             requires_grad=False)
+        # Padding the weight for better performance on ROCm
+        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             is_rocm_aiter_moe_enabled, shuffle_weights)
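The hunk header shows the signature of _maybe_pad_weight but not its body. As a hedged sketch of the usual ROCm padding trick — over-allocate the innermost dimension, then slice the padding back off so the logical shape is unchanged while the row stride grows — using a hypothetical standalone helper name and an assumed 256-byte pad, not vLLM's exact gating logic:

    import torch
    import torch.nn.functional as F

    def maybe_pad_weight(weight: torch.Tensor, pad_bytes: int = 256) -> torch.Tensor:
        # Hypothetical sketch; the real _maybe_pad_weight presumably also gates
        # on ROCm-specific platform/env checks that are omitted here.
        num_pad = pad_bytes // weight.element_size()
        padded = F.pad(weight, (0, num_pad), "constant", 0)
        # Slicing returns a view: same logical shape, but the row stride keeps
        # the extra num_pad elements, spacing rows out in memory.
        return padded[..., :-num_pad]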
@@ -127,10 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             shuffled_w13, shuffled_w2 = shuffle_weights(
                 layer.w13_weight.data, layer.w2_weight.data)
 
-            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
-                                                  requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
-                                                 requires_grad=False)
+            layer.w13_weight.data = shuffled_w13
+            layer.w2_weight.data = shuffled_w2
 
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
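Both hunks swap torch.nn.Parameter re-wrapping for in-place .data assignment. One practical difference: mutating .data leaves the registered Parameter object itself intact, so references to it held elsewhere stay valid, whereas rebuilding the Parameter registers a new object under the same name. A minimal standalone sketch of the distinction (plain torch, not vLLM code):

    import torch

    layer = torch.nn.Linear(4, 4, bias=False)
    ref = layer.weight                      # reference held elsewhere

    # Style used by this diff: mutate .data; the Parameter object survives.
    layer.weight.data = torch.zeros_like(layer.weight.data)
    assert ref is layer.weight

    # Old style: a new Parameter replaces the registered one, leaving the
    # previously captured reference stale.
    layer.weight = torch.nn.Parameter(torch.zeros_like(ref.data),
                                      requires_grad=False)
    assert ref is not layer.weight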