#define WARP_SIZE 64
namespace aiter
{
+    template <typename T, typename F>
+    __device__ constexpr T wave_reduce(T local, F reduce_f)
+    {
+        constexpr int reduce_stage = 6; // 1<<6=64
+        T v_local = local;
+        #pragma unroll
+        for (int i_stage = 0; i_stage < reduce_stage; i_stage++)
+        {
+            int src_lane = __lane_id() ^ (1 << i_stage);
+            int32_t v_remote_tmp =
+                __builtin_amdgcn_ds_bpermute(src_lane << 2, __builtin_bit_cast(int32_t, v_local));
+            T v_remote = __builtin_bit_cast(T, v_remote_tmp);
+            v_local = reduce_f(v_local, v_remote);
+        }
+        return v_local;
+    }
+
    __inline__ __device__ void warpReduceMax(float &val, int &idx)
    {
        static_assert(64 == WARP_SIZE, "WARP_SIZE == 64");
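A quick way to exercise the new wave_reduce helper; this test kernel is illustrative only (not part of the change) and assumes a 64-thread block and a 4-byte element type, since the helper round-trips values through int32_t via ds_bpermute.

// Illustrative smoke test (not in this PR): every lane contributes its lane id,
// so after the butterfly reduction all 64 lanes hold 0 + 1 + ... + 63 = 2016.
__global__ void wave_reduce_smoke_test(int *out)
{
    int v = aiter::wave_reduce(static_cast<int>(__lane_id()),
                               [](int a, int b) { return a + b; });
    if (threadIdx.x == 0)
    {
        out[blockIdx.x] = v; // expected 2016 when launched with block size 64
    }
}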
@@ -63,8 +80,8 @@ namespace aiter
        __syncthreads();
    }

-    template <typename DTYPE_I, typename fvec, int NUM_GRP, bool need_renorm>
-    __global__ void biased_grouped_topk_kernel(
+    template <typename DTYPE_I, typename fvec, int NUM_GRP, bool need_renorm, bool isBiased, bool isSoftmax>
+    __global__ void grouped_topk_kernel(
        const DTYPE_I *__restrict__ gating_output,   // [num_tokens, hidden_size]
        const DTYPE_I *__restrict__ correction_bias, // [num_expert]
        float *__restrict__ topk_weights,            // [num_tokens, topk]
@@ -106,36 +123,112 @@ namespace aiter
        fvec *scores_vec = reinterpret_cast<fvec *>(scores);
        constexpr uint32_t vec_size = sizeof(fvec) / sizeof(float);

-        for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+        if constexpr (!isSoftmax)
        {
-            float gating = static_cast<float>(gating_output[token_idx * num_experts + e]);
-            float score = 1.0f / (1.0f + expf(-gating));
-            scores[e] = score + correction_bias[e];
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                float gating = static_cast<float>(gating_output[token_idx * num_experts + e]);
+                gating = 1.0f / (1.0f + expf(-gating));
+                if constexpr (isBiased)
+                {
+                    gating += correction_bias[e];
+                }
+                scores[e] = gating;
+            }
+            __syncthreads();
        }
+        else
+        {
+            __shared__ float sdata;
+            float max_val = -INFINITY;
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {

+                float gating = gating_output[token_idx * num_experts + e];
+                scores[e] = gating;
+                if (gating > max_val)
+                {
+                    max_val = gating;
+                }
+            }
+            __syncthreads();
        #pragma unroll
-        for (int g = threadIdx.x; g < NUM_GRP; g += blockDim.x)
-        {
-            float max1 = -INFINITY, max2 = -INFINITY;
-            const int start = g * experts_per_group;
-            const int end = start + experts_per_group;
+            for (int i = 0; i < 6; i++)
+            {
+                int offset = 1 << i;
+                float tmp_val = __shfl_down(max_val, offset);
+                if (tmp_val > max_val)
+                {
+                    max_val = tmp_val;
+                }
+            }
+            if (threadIdx.x == 0)
+            {
+                sdata = max_val;
+            }
+            __syncthreads();
+            max_val = sdata;
+            float thread_sum = 0.0;
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                scores[e] = expf(scores[e] - max_val);
+                thread_sum += scores[e];
+            }
+            __syncthreads();
+            thread_sum = wave_reduce(thread_sum, [](float a, float b) { return a + b; });
+            for (int e = threadIdx.x; e < num_experts; e += blockDim.x)
+            {
+                scores[e] /= thread_sum;
+            }
+            __syncthreads();
+        }

-            for (int e = start; e < end; ++e)
+        if constexpr (isBiased)
+        {
+            #pragma unroll
+            for (int g = threadIdx.x; g < NUM_GRP; g += blockDim.x)
            {
-                if (scores[e] > max1)
+                float max1 = -INFINITY, max2 = -INFINITY;
+                const int start = g * experts_per_group;
+                const int end = start + experts_per_group;
+
+                for (int e = start; e < end; ++e)
                {
-                    max2 = max1;
-                    max1 = scores[e];
+                    if (scores[e] > max1)
+                    {
+                        max2 = max1;
+                        max1 = scores[e];
+                    }
+                    else if (scores[e] > max2)
+                    {
+                        max2 = scores[e];
+                    }
                }
-                else if (scores[e] > max2)
+                group_scores[g] = max1 + max2;
+                group_mask[g] = false;
+            }
+            __syncthreads();
+        }
+        else
+        {
+            #pragma unroll
+            for (int g = threadIdx.x; g < NUM_GRP; g += blockDim.x)
+            {
+                float max1 = -INFINITY;
+                const int start = g * experts_per_group;
+                const int end = start + experts_per_group;
+                for (int e = start; e < end; ++e)
                {
-                    max2 = scores[e];
+                    if (scores[e] > max1)
+                    {
+                        max1 = scores[e];
+                    }
                }
+                group_scores[g] = max1;
+                group_mask[g] = false;
            }
-            group_scores[g] = max1 + max2;
-            group_mask[g] = false;
+            __syncthreads();
        }
-        __syncthreads();

        for (int k = 0; k < topk_group; k++)
        {
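For reference, the new softmax path above is the standard numerically stable softmax over a token's gating row: subtract the row max before exponentiating, then normalize by the sum (computed in the kernel with the shuffle and wave_reduce reductions). A plain host-side sketch of the same math, purely illustrative:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference: scores[e] = exp(x[e] - max(x)) / sum_j exp(x[j] - max(x)).
std::vector<float> softmax_reference(const std::vector<float> &gating)
{
    const float max_val = *std::max_element(gating.begin(), gating.end());
    std::vector<float> scores(gating.size());
    float sum = 0.0f;
    for (std::size_t e = 0; e < gating.size(); ++e)
    {
        scores[e] = std::exp(gating[e] - max_val);
        sum += scores[e];
    }
    for (float &s : scores)
        s /= sum;
    return scores;
}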
@@ -205,7 +298,10 @@ namespace aiter
                max_idx = k;
                max_val = scores[max_idx];
            }
-            max_val -= correction_bias[max_idx];
+            if constexpr (isBiased)
+            {
+                max_val -= correction_bias[max_idx];
+            }
            scores[max_idx] = -INFINITY;
            topk_indices[k] = max_idx;
            topk_values[k] = max_val;
@@ -233,7 +329,7 @@ namespace aiter

        for (int k = threadIdx.x; k < topk; k += blockDim.x)
        {
-            topk_weights[token_idx * stride_tk + k] = need_renorm ? topk_values[k] * sum : topk_values[k];
+            topk_weights[token_idx * stride_tk + k] = topk_values[k] * sum;
            topk_ids[token_idx * stride_tk + k] = topk_indices[k];
        }
    }
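The final write now always multiplies by `sum`; how `sum` is produced is outside this hunk, so the following host-side reference is only an assumption about the intended result: with renormalization the top-k weights are divided by their own sum and scaled by routed_scaling_factor, otherwise only the scaling factor is applied.

#include <cstddef>
#include <numeric>
#include <vector>

// Assumed reference for the final top-k weights (the kernel's `sum` is not shown in this diff).
std::vector<float> finalize_topk_weights(const std::vector<float> &topk_values,
                                         bool need_renorm, float routed_scaling_factor)
{
    float scale = routed_scaling_factor;
    if (need_renorm)
    {
        scale /= std::accumulate(topk_values.begin(), topk_values.end(), 0.0f);
    }
    std::vector<float> weights(topk_values.size());
    for (std::size_t k = 0; k < topk_values.size(); ++k)
    {
        weights[k] = topk_values[k] * scale;
    }
    return weights;
}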
@@ -281,18 +377,49 @@ namespace aiter
        LAUNCHER4(VEC_F, NUM_GRP, false) \
    }

-#define LAUNCHER4(VEC_F, NUM_GRP, need_renorm) \
-    VLLM_DISPATCH_FLOATING_TYPES( \
-        gating_output.scalar_type(), "biased_grouped_topk_kernel", [&] \
-        { aiter::biased_grouped_topk_kernel<scalar_t, VEC_F, NUM_GRP, need_renorm> \
-              <<<grid, block, shared_mem_size, stream>>>( \
-                  gating_output.data_ptr<scalar_t>(), \
-                  correction_bias.data_ptr<scalar_t>(), \
-                  topk_weights.data_ptr<float>(), \
-                  topk_ids.data_ptr<int>(), \
-                  stride_tk, \
-                  num_experts, \
-                  topk, \
+#define LAUNCHER4(VEC_F, NUM_GRP, need_renorm) \
+    if constexpr (isBiased) \
+    { \
+        LAUNCHER_biased_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, true, false) \
+    } \
+    else \
+    { \
+        if (isSoftmax) \
+        { \
+            LAUNCHER_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, false, true) \
+        } \
+        else \
+        { \
+            LAUNCHER_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, false, false) \
+        } \
+    }
+
+#define LAUNCHER_biased_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax) \
+    VLLM_DISPATCH_FLOATING_TYPES( \
+        gating_output.scalar_type(), "biased_grouped_topk_kernel", [&] \
+        { aiter::grouped_topk_kernel<scalar_t, VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax> \
+              <<<grid, block, shared_mem_size, stream>>>( \
+                  gating_output.data_ptr<scalar_t>(), \
+                  correction_bias.data_ptr<scalar_t>(), \
+                  topk_weights.data_ptr<float>(), \
+                  topk_ids.data_ptr<int>(), \
+                  stride_tk, \
+                  num_experts, \
+                  topk, \
+                  topk_grp, num_tokens, routed_scaling_factor); });
+
+#define LAUNCHER_grouped_topk_kernel(VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax) \
+    VLLM_DISPATCH_FLOATING_TYPES( \
+        gating_output.scalar_type(), "grouped_topk_kernel", [&] \
+        { aiter::grouped_topk_kernel<scalar_t, VEC_F, NUM_GRP, need_renorm, isBiased, isSoftmax> \
+              <<<grid, block, shared_mem_size, stream>>>( \
+                  gating_output.data_ptr<scalar_t>(), \
+                  nullptr, \
+                  topk_weights.data_ptr<float>(), \
+                  topk_ids.data_ptr<int>(), \
+                  stride_tk, \
+                  num_experts, \
+                  topk, \
                  topk_grp, num_tokens, routed_scaling_factor); });

void biased_grouped_topk(
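The reworked LAUNCHER4 splits dispatch in two steps: the biased/unbiased choice is a compile-time constant in the host wrappers (a `const bool isBiased` feeding `if constexpr`), while softmax vs. sigmoid stays a plain runtime branch. A minimal standalone sketch of that pattern, with purely illustrative names:

#include <cstdio>

template <bool IsBiased, bool IsSoftmax>
void launch_stub() // stands in for the real kernel launch
{
    std::printf("biased=%d softmax=%d\n", IsBiased, IsSoftmax);
}

void dispatch_example(bool isSoftmax)
{
    const bool isBiased = false; // compile-time constant, as in the host wrappers below
    if constexpr (isBiased)
    {
        launch_stub<true, false>();
    }
    else if (isSoftmax) // runtime choice between the two unbiased variants
    {
        launch_stub<false, true>();
    }
    else
    {
        launch_stub<false, false>();
    }
}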
@@ -303,8 +430,10 @@ void biased_grouped_topk(
    int num_expert_group,
    int topk_grp,
    bool need_renorm,
-    const float routed_scaling_factor)
+    const float routed_scaling_factor = 1.)
{
+    const bool isBiased = true;
+    bool isSoftmax = false;
    int num_tokens = gating_output.size(0);
    int num_experts = gating_output.size(1);
    int topk = topk_ids.size(1);
@@ -326,6 +455,42 @@ void biased_grouped_topk(

    LAUNCH_KERNEL()
}
+
+void grouped_topk(
+    torch::Tensor &gating_output, // [num_tokens, num_experts]
+    torch::Tensor &topk_weights,  // [num_tokens, topk]
+    torch::Tensor &topk_ids,      // [num_tokens, topk]
+    int num_expert_group,
+    int topk_grp,
+    bool need_renorm,
+    std::string scoring_func = "softmax",
+    const float routed_scaling_factor = 1.)
+{
+    TORCH_CHECK((scoring_func == "softmax") || (scoring_func == "sigmoid"), "grouped_topk scoring_func only supports softmax or sigmoid");
+    const bool isBiased = false;
+    bool isSoftmax = (scoring_func == "softmax");
+    int num_tokens = gating_output.size(0);
+    int num_experts = gating_output.size(1);
+    int topk = topk_ids.size(1);
+    size_t stride_tk = topk_ids.stride(0);
+    auto correction_bias = topk_ids;
+    TORCH_CHECK(stride_tk == topk_weights.stride(0), "topk_ids.stride(0) == topk_weights.stride(0)");
+
+    dim3 grid(num_tokens);
+    dim3 block(64);
+    size_t shared_mem_size = (num_experts * sizeof(float) +
+                              num_expert_group * sizeof(float) +
+                              num_expert_group * sizeof(bool) +
+                              topk * sizeof(int) +
+                              topk * sizeof(float) + 255) &
+                             ~255;
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    LAUNCH_KERNEL()
+}
+
#undef LAUNCHER4
#undef LAUNCHER3
#undef LAUNCHER2
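A hedged usage sketch of the new grouped_topk entry point from C++; the tensor shapes follow the comments in its signature, while the dtypes, expert/group counts, and the header that exposes the declaration are assumptions for illustration only.

#include <torch/torch.h>
// Assumes the grouped_topk declaration from this file is visible
// (e.g. via the extension's ops header, which is not shown in this diff).

void grouped_topk_example()
{
    const int num_tokens = 8, num_experts = 64, topk = 6;
    auto opt_f = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto opt_i = torch::dtype(torch::kInt32).device(torch::kCUDA);

    torch::Tensor gating_output = torch::randn({num_tokens, num_experts}, opt_f);
    torch::Tensor topk_weights  = torch::empty({num_tokens, topk}, opt_f);
    torch::Tensor topk_ids      = torch::empty({num_tokens, topk}, opt_i);

    grouped_topk(gating_output, topk_weights, topk_ids,
                 /*num_expert_group=*/8, /*topk_grp=*/4,
                 /*need_renorm=*/true, /*scoring_func=*/"softmax",
                 /*routed_scaling_factor=*/1.0f);
}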