Commit e12d350: update group topk (#175)

* update group topk
* update

1 parent: c0da3d2

File tree: 5 files changed, +328 −40 lines

aiter/ops/topk.py
Lines changed: 52 additions & 1 deletion
```diff
@@ -17,9 +17,20 @@ def biased_grouped_topk(
     num_expert_group: int,
     topk_group: int,
     need_renorm: bool,
-    routed_scaling_factor: float  # mul to topk_weights
+    routed_scaling_factor: float = 1.0  # mul to topk_weights
 ): ...
 
+@compile_ops("module_moe_asm")
+def grouped_topk(
+    gating_output: Tensor,
+    topk_weights: Tensor,
+    topk_ids: Tensor,
+    num_expert_group: int,
+    topk_group: int,
+    need_renorm: bool,
+    scoring_func: str = "softmax",
+    scale_factor: float = 1.0,
+): ...
 
 # this one copied from sglang
 def biased_grouped_topk_torch(
@@ -62,3 +73,43 @@ def biased_grouped_topk_torch(
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+# this one copied from sglang
+def grouped_topk_torch(
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    scoring_func: str = "softmax",
+):
+    gating_output = gating_output.to(torch.float)
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
+
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
```
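For orientation, here is a minimal sketch of how the torch reference added above might be exercised. The shapes and hyperparameters are illustrative, and the import assumes the module layout shown in this diff:

```python
import torch

from aiter.ops.topk import grouped_topk_torch

# 4 tokens routed over 16 experts arranged in 4 groups of 4 (illustrative).
gating = torch.randn(4, 16)

weights, ids = grouped_topk_torch(
    gating,
    topk=2,             # experts kept per token
    renormalize=True,   # per-token weights sum to 1
    num_expert_group=4,
    topk_group=2,       # groups kept before the final top-k
    scoring_func="softmax",
)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```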

csrc/include/moe_op.h
Lines changed: 11 additions & 1 deletion
```diff
@@ -16,7 +16,17 @@ void biased_grouped_topk(
     int num_expert_group,
     int topk_group,
     bool renormalize,
-    const float routed_scaling_factor);
+    const float routed_scaling_factor = 1.);
+
+void grouped_topk(
+    torch::Tensor &gating_output, // [num_tokens, num_experts]
+    torch::Tensor &topk_weights,  // [num_tokens, topk]
+    torch::Tensor &topk_ids,      // [num_tokens, topk]
+    int num_expert_group,
+    int topk_grp,
+    bool need_renorm,
+    std::string scoring_func = "softmax",
+    const float routed_scaling_factor = 1.);
 
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
```
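The declaration mirrors the Python binding above: `topk_weights` and `topk_ids` are caller-allocated outputs that the kernel fills in place. A hedged sketch of a round trip against the torch reference (the device, dtypes, and keyword spelling of the compiled op are assumptions, not confirmed by this diff):

```python
import torch

from aiter.ops.topk import grouped_topk, grouped_topk_torch

num_tokens, num_experts, topk = 8, 64, 4
gating = torch.randn(num_tokens, num_experts, dtype=torch.float16, device="cuda")

# Outputs are allocated by the caller; the op writes into them.
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

grouped_topk(gating, topk_weights, topk_ids,
             num_expert_group=8, topk_group=4,
             need_renorm=True, scoring_func="softmax")

# Cross-check against the torch reference added in this commit.
# Sort along the last dim since both sides use sorted=False ordering.
ref_w, ref_ids = grouped_topk_torch(gating, topk, True,
                                    num_expert_group=8, topk_group=4)
torch.testing.assert_close(topk_weights.sort(-1).values,
                           ref_w.sort(-1).values, rtol=1e-3, atol=1e-3)
```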

csrc/kernels/topk_softmax_kernels_group.cu
Lines changed: 200 additions & 35 deletions
```diff
@@ -19,6 +19,23 @@
 #define WARP_SIZE 64
 namespace aiter
 {
+    template <typename T, typename F>
+    __device__ constexpr T wave_reduce(T local, F reduce_f)
+    {
+        constexpr int reduce_stage = 6; // 1<<6 = 64
+        T v_local = local;
+#pragma unroll
+        for (int i_stage = 0; i_stage < reduce_stage; i_stage++)
+        {
+            int src_lane = __lane_id() ^ (1 << i_stage);
+            int32_t v_remote_tmp =
+                __builtin_amdgcn_ds_bpermute(src_lane << 2, __builtin_bit_cast(int32_t, v_local));
+            T v_remote = __builtin_bit_cast(T, v_remote_tmp);
+            v_local = reduce_f(v_local, v_remote);
+        }
+        return v_local;
+    }
+
     __inline__ __device__ void warpReduceMax(float &val, int &idx)
     {
         static_assert(64 == WARP_SIZE, "WARP_SIZE == 64");
```
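`wave_reduce` is a butterfly reduction over the 64 lanes of a wavefront: at stage `i` each lane exchanges values with the lane whose index differs in bit `i` (via `ds_bpermute`), so after log2(64) = 6 stages every lane holds the reduction of all 64. A Python model of the exchange pattern (a hypothetical helper, for intuition only):

```python
# Model of the XOR butterfly in wave_reduce: lane i pairs with lane i ^ (1 << stage).
def wave_reduce_model(lanes, reduce_f):
    n = len(lanes)                            # wavefront size, 64 on these GPUs
    vals = list(lanes)
    for stage in range(n.bit_length() - 1):   # log2(n) exchange stages
        vals = [reduce_f(vals[i], vals[i ^ (1 << stage)]) for i in range(n)]
    return vals                               # every lane ends with the full reduction

out = wave_reduce_model(list(range(64)), lambda a, b: a + b)
assert all(v == sum(range(64)) for v in out)
```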
```diff
@@ -63,8 +80,8 @@ namespace aiter
         __syncthreads();
     }
 
-    template <typename DTYPE_I, typename fvec, int NUM_GRP, bool need_renorm>
-    __global__ void biased_grouped_topk_kernel(
+    template <typename DTYPE_I, typename fvec, int NUM_GRP, bool need_renorm, bool isBiased, bool isSoftmax>
+    __global__ void grouped_topk_kernel(
         const DTYPE_I *__restrict__ gating_output,   // [num_tokens, hidden_size]
         const DTYPE_I *__restrict__ correction_bias, // [num_expert]
         float *__restrict__ topk_weights,            // [num_tokens, topk]
```
```diff
@@ -106,36 +123,112 @@ namespace aiter
         fvec *scores_vec = reinterpret_cast<fvec *>(scores);
         constexpr uint32_t vec_size = sizeof(fvec) / sizeof(float);
 
-        for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+        if constexpr (!isSoftmax)
         {
-            float gating = static_cast<float>(gating_output[token_idx * num_experts + e]);
-            float score = 1.0f / (1.0f + expf(-gating));
-            scores[e] = score + correction_bias[e];
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                float gating = static_cast<float>(gating_output[token_idx * num_experts + e]);
+                gating = 1.0f / (1.0f + expf(-gating));
+                if constexpr (isBiased)
+                {
+                    gating += correction_bias[e];
+                }
+                scores[e] = gating;
+            }
+            __syncthreads();
         }
+        else
+        {
+            __shared__ float sdata;
+            float max_val = -INFINITY;
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                float gating = gating_output[token_idx * num_experts + e];
+                scores[e] = gating;
+                if (gating > max_val)
+                {
+                    max_val = gating;
+                }
+            }
+            __syncthreads();
+#pragma unroll
+            for (int i = 0; i < 6; i++)
+            {
+                int offset = 1 << i;
+                float tmp_val = __shfl_down(max_val, offset);
+                if (tmp_val > max_val)
+                {
+                    max_val = tmp_val;
+                }
+            }
+            if (threadIdx.x == 0)
+            {
+                sdata = max_val;
+            }
+            __syncthreads();
+            max_val = sdata;
+            float thread_sum = 0.0;
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                scores[e] = expf(scores[e] - max_val);
+                thread_sum += scores[e];
+            }
+            __syncthreads();
+            thread_sum = wave_reduce(thread_sum, [](float a, float b) { return a + b; });
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                scores[e] /= thread_sum;
+            }
+            __syncthreads();
+        }
 
-#pragma unroll
-        for (int g = threadIdx.x; g < NUM_GRP; g += blockDim.x)
-        {
-            float max1 = -INFINITY, max2 = -INFINITY;
-            const int start = g * experts_per_group;
-            const int end = start + experts_per_group;
-
-            for (int e = start; e < end; ++e)
-            {
-                if (scores[e] > max1)
-                {
-                    max2 = max1;
-                    max1 = scores[e];
-                }
-                else if (scores[e] > max2)
-                {
-                    max2 = scores[e];
-                }
-            }
-            group_scores[g] = max1 + max2;
-            group_mask[g] = false;
-        }
-        __syncthreads();
+        if constexpr (isBiased)
+        {
+#pragma unroll
+            for (int g = threadIdx.x; g < NUM_GRP; g += blockDim.x)
+            {
+                float max1 = -INFINITY, max2 = -INFINITY;
+                const int start = g * experts_per_group;
+                const int end = start + experts_per_group;
+
+                for (int e = start; e < end; ++e)
+                {
+                    if (scores[e] > max1)
+                    {
+                        max2 = max1;
+                        max1 = scores[e];
+                    }
+                    else if (scores[e] > max2)
+                    {
+                        max2 = scores[e];
+                    }
+                }
+                group_scores[g] = max1 + max2;
+                group_mask[g] = false;
+            }
+            __syncthreads();
+        }
+        else
+        {
+#pragma unroll
+            for (int g = threadIdx.x; g < NUM_GRP; g += blockDim.x)
+            {
+                float max1 = -INFINITY;
+                const int start = g * experts_per_group;
+                const int end = start + experts_per_group;
+                for (int e = start; e < end; ++e)
+                {
+                    if (scores[e] > max1)
+                    {
+                        max1 = scores[e];
+                    }
+                }
+                group_scores[g] = max1;
+                group_mask[g] = false;
+            }
+            __syncthreads();
+        }
 
         for (int k = 0; k < topk_group; k++)
         {
```
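This hunk splits the scoring and group-selection logic: the sigmoid path (optionally bias-corrected) fills `scores` directly; the softmax path computes a numerically stable softmax (block-wide max via `__shfl_down`, normalization via `wave_reduce`); and group scores use a top-2 sum when biased but a plain per-group max otherwise, matching `grouped_topk_torch`. The two group-scoring rules, restated in torch for clarity (shapes illustrative):

```python
import torch

scores = torch.rand(1, 16)                  # 1 token, 16 experts
groups = scores.view(1, 4, -1)              # 4 groups of 4 experts

# Biased path: a group is scored by the sum of its two best experts.
biased_group_scores = groups.topk(2, dim=-1).values.sum(dim=-1)  # [1, 4]

# Unbiased path (new in this commit): a group is scored by its single
# best expert, as in grouped_topk_torch's max(dim=-1).
plain_group_scores = groups.max(dim=-1).values                   # [1, 4]
```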
```diff
@@ -205,7 +298,10 @@ namespace aiter
                 max_idx = k;
                 max_val = scores[max_idx];
             }
-            max_val -= correction_bias[max_idx];
+            if constexpr (isBiased)
+            {
+                max_val -= correction_bias[max_idx];
+            }
             scores[max_idx] = -INFINITY;
             topk_indices[k] = max_idx;
             topk_values[k] = max_val;
```
```diff
@@ -233,7 +329,7 @@ namespace aiter
 
         for (int k = threadIdx.x; k < topk; k += blockDim.x)
         {
-            topk_weights[token_idx * stride_tk + k] = need_renorm ? topk_values[k] * sum : topk_values[k];
+            topk_weights[token_idx * stride_tk + k] = topk_values[k] * sum;
             topk_ids[token_idx * stride_tk + k] = topk_indices[k];
         }
     }
```
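The per-element ternary disappears because renormalization and `routed_scaling_factor` can be folded into the single per-token multiplier `sum` computed earlier in the kernel; presumably `sum` now already accounts for the `need_renorm` case upstream. The algebra, sketched in Python under that assumption:

```python
# Assumed folding: one per-token multiplier replaces a per-element branch.
def fold_scale(topk_values, need_renorm, routed_scaling_factor=1.0):
    total = sum(topk_values)
    scale = routed_scaling_factor / total if need_renorm else routed_scaling_factor
    return [v * scale for v in topk_values]  # == topk_values[k] * sum in the kernel
```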
```diff
@@ -281,18 +377,49 @@ namespace aiter
         LAUNCHER4(VEC_F, NUM_GRP, false) \
 }
 
-#define LAUNCHER4(VEC_F, NUM_GRP, need_renorm)                                      \
-    VLLM_DISPATCH_FLOATING_TYPES(                                                   \
-        gating_output.scalar_type(), "biased_grouped_topk_kernel", [&]              \
-        { aiter::biased_grouped_topk_kernel<scalar_t, VEC_F, NUM_GRP, need_renorm>  \
-              <<<grid, block, shared_mem_size, stream>>>(                           \
-                  gating_output.data_ptr<scalar_t>(),                               \
-                  correction_bias.data_ptr<scalar_t>(),                             \
-                  topk_weights.data_ptr<float>(),                                   \
-                  topk_ids.data_ptr<int>(),                                         \
-                  stride_tk,                                                        \
-                  num_experts,                                                      \
-                  topk,                                                             \
+#define LAUNCHER4(VEC_F, NUM_GRP, need_renorm)                                        \
+    if constexpr (isBiased)                                                           \
+    {                                                                                 \
+        LAUNCHER_biased_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, true, false) \
+    }                                                                                 \
+    else                                                                              \
+    {                                                                                 \
+        if (isSoftmax)                                                                \
+        {                                                                             \
+            LAUNCHER_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, false, true)    \
+        }                                                                             \
+        else                                                                          \
+        {                                                                             \
+            LAUNCHER_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, false, false)   \
+        }                                                                             \
+    }
+
+#define LAUNCHER_biased_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax)     \
+    VLLM_DISPATCH_FLOATING_TYPES(                                                                 \
+        gating_output.scalar_type(), "biased_grouped_topk_kernel", [&]                            \
+        { aiter::grouped_topk_kernel<scalar_t, VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax>  \
+              <<<grid, block, shared_mem_size, stream>>>(                                         \
+                  gating_output.data_ptr<scalar_t>(),                                             \
+                  correction_bias.data_ptr<scalar_t>(),                                           \
+                  topk_weights.data_ptr<float>(),                                                 \
+                  topk_ids.data_ptr<int>(),                                                       \
+                  stride_tk,                                                                      \
+                  num_experts,                                                                    \
+                  topk,                                                                           \
+                  topk_grp, num_tokens, routed_scaling_factor); });
+
+#define LAUNCHER_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax)            \
+    VLLM_DISPATCH_FLOATING_TYPES(                                                                 \
+        gating_output.scalar_type(), "grouped_topk_kernel", [&]                                   \
+        { aiter::grouped_topk_kernel<scalar_t, VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax>  \
+              <<<grid, block, shared_mem_size, stream>>>(                                         \
+                  gating_output.data_ptr<scalar_t>(),                                             \
+                  nullptr,                                                                        \
+                  topk_weights.data_ptr<float>(),                                                 \
+                  topk_ids.data_ptr<int>(),                                                       \
+                  stride_tk,                                                                      \
+                  num_experts,                                                                    \
+                  topk,                                                                           \
                   topk_grp, num_tokens, routed_scaling_factor); });
 
 void biased_grouped_topk(
```
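`LAUNCHER4` now routes both entry points through the unified `aiter::grouped_topk_kernel`, baking `isBiased`/`isSoftmax` in as template parameters: the biased API always launches `(true, false)` with a real `correction_bias` pointer, while the new unbiased API passes `nullptr` and selects `(false, true)` or `(false, false)` from `scoring_func`. The mapping, as a small Python table for reference:

```python
# (is_biased, scoring_func) -> (isBiased, isSoftmax) template flags.
def kernel_flags(is_biased: bool, scoring_func: str = "softmax"):
    if is_biased:
        return (True, False)  # biased path is always sigmoid + bias
    return (False, scoring_func == "softmax")

assert kernel_flags(True) == (True, False)
assert kernel_flags(False, "softmax") == (False, True)
assert kernel_flags(False, "sigmoid") == (False, False)
```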
```diff
@@ -303,8 +430,10 @@ void biased_grouped_topk(
     int num_expert_group,
     int topk_grp,
     bool need_renorm,
-    const float routed_scaling_factor)
+    const float routed_scaling_factor = 1.)
 {
+    const bool isBiased = true;
+    bool isSoftmax = false;
     int num_tokens = gating_output.size(0);
     int num_experts = gating_output.size(1);
     int topk = topk_ids.size(1);
```
```diff
@@ -326,6 +455,42 @@ void biased_grouped_topk(
 
     LAUNCH_KERNEL()
 }
+
+void grouped_topk(
+    torch::Tensor &gating_output, // [num_tokens, num_experts]
+    torch::Tensor &topk_weights,  // [num_tokens, topk]
+    torch::Tensor &topk_ids,      // [num_tokens, topk]
+    int num_expert_group,
+    int topk_grp,
+    bool need_renorm,
+    std::string scoring_func = "softmax",
+    const float routed_scaling_factor = 1.)
+{
+    TORCH_CHECK((scoring_func == "softmax") || (scoring_func == "sigmoid"),
+                "grouped_topk scoring_func only supports softmax or sigmoid");
+    const bool isBiased = false;
+    bool isSoftmax = scoring_func == "softmax";
+    int num_tokens = gating_output.size(0);
+    int num_experts = gating_output.size(1);
+    int topk = topk_ids.size(1);
+    size_t stride_tk = topk_ids.stride(0);
+    auto correction_bias = topk_ids;
+    TORCH_CHECK(stride_tk == topk_weights.stride(0), "topk_ids.stride(0) == topk_weights.stride(0)");
+
+    dim3 grid(num_tokens);
+    dim3 block(64);
+    size_t shared_mem_size = (num_experts * sizeof(float) +
+                              num_expert_group * sizeof(float) +
+                              num_expert_group * sizeof(bool) +
+                              topk * sizeof(int) +
+                              topk * sizeof(float) + 255) &
+                             ~255;
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    LAUNCH_KERNEL()
+}
+
 #undef LAUNCHER4
 #undef LAUNCHER3
 #undef LAUNCHER2
```
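The shared-memory budget covers `scores`, the per-group scores and masks, and the top-k index/value scratch, rounded up to a 256-byte multiple by `(x + 255) & ~255`. Recomputing it in Python for a sample configuration:

```python
# Shared-memory sizing used by grouped_topk, reproduced for one config.
def shared_mem_bytes(num_experts, num_expert_group, topk):
    raw = (num_experts * 4 +        # float scores
           num_expert_group * 4 +   # float group_scores
           num_expert_group * 1 +   # bool group_mask
           topk * 4 +               # int topk_indices
           topk * 4)                # float topk_values
    return (raw + 255) & ~255       # round up to a 256-byte multiple

assert shared_mem_bytes(64, 8, 4) == 512  # raw = 328 bytes -> 512
```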
