Skip to content

Commit c0b5d1f

Browse files
EC/ROCM: Prod overload issue for HIP complex (#783)
(cherry picked from commit c2a5062) Co-authored-by: Pedram Alizadeh <pmohamma@amd.com>
1 parent a036a5f commit c0b5d1f

File tree

1 file changed

+95
-5
lines changed

1 file changed

+95
-5
lines changed

src/components/ec/rocm/kernel/ec_rocm_reduce.cu

Lines changed: 95 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include "ec_rocm.h"
99
#include "utils/ucc_math_op.h"
1010
#include <inttypes.h>
11+
#include <hip/hip_complex.h>
1112

1213
#define ROCM_REDUCE_WITH_OP_DEFAULT(NAME, _OP) \
1314
template <typename _Type, typename _AlphaType> \
@@ -54,6 +55,41 @@
5455
} \
5556
}
5657

58+
/*
 * Generates the kernel template UCC_REDUCE_ROCM_DEFAULT_COMPLEX_<NAME>.
 *
 * NAME - suffix appended to the generated kernel name.
 * _OP  - binary function combining two _Type values (instantiated in this
 *        file with hipCmul / hipCmulf).
 *
 * Kernel arguments:
 *   task  - reduce descriptor; the kernel reads task.count, task.n_srcs,
 *           task.srcs, task.dst and task.alpha.
 *   flags - when UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA is set, every reduced
 *           element is additionally multiplied by (_AlphaType)task.alpha.
 *
 * For each element index i visited by the calling thread the kernel computes
 *   dst[i] = _OP(..._OP(_OP(srcs[0][i], srcs[1][i]), srcs[2][i])..., srcs[n_srcs-1][i])
 * srcs[0] and srcs[1] are always read, i.e. at least two sources are assumed.
 *
 * The partial result is kept in a local variable and dst[i] is written once,
 * after all sources for that index have been read; this avoids repeated
 * stores to dst and keeps the result well defined if dst aliases a source.
 */
#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(NAME, _OP)                    \
    template <typename _Type, typename _AlphaType>                             \
    __global__ void UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME(                    \
        ucc_eee_task_reduce_t task, uint16_t flags)                            \
    {                                                                          \
        /* widen to size_t BEFORE multiplying so the index math cannot wrap */ \
        const size_t  start      = (size_t)blockIdx.x * blockDim.x +           \
                                   threadIdx.x;                                \
        const size_t  step       = (size_t)blockDim.x * gridDim.x;             \
        const size_t  count      = task.count;                                 \
        const int     n_srcs     = task.n_srcs;                                \
        const _Type **s          = (const _Type **)task.srcs;                  \
        _Type        *d          = (_Type *)task.dst;                          \
        const bool    with_alpha =                                             \
            (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) != 0;                \
        size_t        i;                                                       \
        int           j;                                                       \
                                                                               \
        for (i = start; i < count; i += step) {                                \
            _Type acc = _OP(s[0][i], s[1][i]);                                 \
            /* no iterations when n_srcs == 2 */                               \
            for (j = 2; j < n_srcs; j++) {                                     \
                acc = _OP(acc, s[j][i]);                                       \
            }                                                                  \
            if (with_alpha) {                                                  \
                acc = acc * (_AlphaType)task.alpha;                            \
            }                                                                  \
            d[i] = acc;                                                        \
        }                                                                      \
    }
92+
5793
#define ROCM_REDUCE_WITH_OP_STRIDED(NAME, _OP) \
5894
template <typename _Type, typename _AlphaType> \
5995
__global__ void UCC_REDUCE_ROCM_STRIDED_##NAME( \
@@ -99,8 +135,45 @@
99135
} \
100136
}
101137

138+
/*
 * Generates the kernel template UCC_REDUCE_ROCM_STRIDED_COMPLEX_<NAME>.
 *
 * NAME - suffix appended to the generated kernel name.
 * _OP  - binary function combining two _Type values (instantiated in this
 *        file with hipCmul / hipCmulf).
 *
 * Kernel arguments:
 *   s1         - first source vector.
 *   s2         - base of the remaining sources; source j is read at
 *                s2[i + j * (stride / sizeof(_Type))].
 *   d          - destination vector.
 *   count      - number of elements to reduce.
 *   stride     - distance between consecutive s2 sources, in bytes; must be
 *                a multiple of sizeof(_Type) (asserted).
 *   n_src2     - number of sources reachable through s2.
 *   with_alpha - when true, each result is multiplied by (_AlphaType)alpha.
 *   alpha      - scaling factor used when with_alpha is true.
 *
 * For each element index i visited by the calling thread:
 *   d[i] = _OP(..._OP(_OP(s1[i], s2[i]), s2[i + ld])..., s2[i + (n_src2-1)*ld])
 * s1[i] and s2[i] are always read, regardless of n_src2.
 *
 * The partial result is kept in a local variable and d[i] is written once,
 * after all sources for that index have been read; this avoids repeated
 * stores to d and keeps the result well defined if d aliases a source.
 */
#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(NAME, _OP)                    \
    template <typename _Type, typename _AlphaType>                             \
    __global__ void UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME(                    \
        const _Type *s1, const _Type *s2, _Type *d, size_t count,              \
        size_t stride, uint16_t n_src2, const bool with_alpha,                 \
        const double alpha)                                                    \
    {                                                                          \
        /* widen to size_t BEFORE multiplying so the index math cannot wrap */ \
        const size_t start = (size_t)blockIdx.x * blockDim.x + threadIdx.x;    \
        const size_t step  = (size_t)blockDim.x * gridDim.x;                   \
        /* stride converted from bytes to elements */                          \
        const size_t ld    = stride / sizeof(_Type);                           \
        size_t       i;                                                        \
        size_t       j;                                                        \
                                                                               \
        ucc_assert_system(stride % sizeof(_Type) == 0);                        \
        for (i = start; i < count; i += step) {                                \
            _Type acc = _OP(s1[i], s2[i]);                                     \
            /* no iterations when n_src2 <= 1 */                               \
            for (j = 1; j < n_src2; j++) {                                     \
                acc = _OP(acc, s2[i + j * ld]);                                \
            }                                                                  \
            if (with_alpha) {                                                  \
                acc = acc * (_AlphaType)alpha;                                 \
            }                                                                  \
            d[i] = acc;                                                        \
        }                                                                      \
    }
172+
102173
ROCM_REDUCE_WITH_OP_DEFAULT(SUM, DO_OP_SUM);
103174
ROCM_REDUCE_WITH_OP_DEFAULT(PROD, DO_OP_PROD);
175+
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_DOUBLE, hipCmul);
176+
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_FLOAT, hipCmulf);
104177
ROCM_REDUCE_WITH_OP_DEFAULT(MIN, DO_OP_MIN);
105178
ROCM_REDUCE_WITH_OP_DEFAULT(MAX, DO_OP_MAX);
106179
ROCM_REDUCE_WITH_OP_DEFAULT(LAND, DO_OP_LAND);
@@ -112,6 +185,8 @@ ROCM_REDUCE_WITH_OP_DEFAULT(BXOR, DO_OP_BXOR);
112185

113186
ROCM_REDUCE_WITH_OP_STRIDED(SUM, DO_OP_SUM);
114187
ROCM_REDUCE_WITH_OP_STRIDED(PROD, DO_OP_PROD);
188+
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_DOUBLE, hipCmul);
189+
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_FLOAT, hipCmulf);
115190
ROCM_REDUCE_WITH_OP_STRIDED(MIN, DO_OP_MIN);
116191
ROCM_REDUCE_WITH_OP_STRIDED(MAX, DO_OP_MAX);
117192
ROCM_REDUCE_WITH_OP_STRIDED(LAND, DO_OP_LAND);
@@ -136,6 +211,21 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
136211
} \
137212
} while (0)
138213

214+
/*
 * Launches the complex-product reduction kernel selected by NAME.
 *
 * NAME       - kernel suffix (e.g. PROD_FLOAT, PROD_DOUBLE).
 * type       - element type the kernel template is instantiated with.
 * _AlphaType - scalar type the alpha factor is cast to inside the kernel.
 * _task      - pointer to the executor task; task_type picks the kernel.
 * s, b, t    - stream, grid size and block size used in the launch
 *              configuration <<<b, t, 0, s>>>.
 *
 * A task of type UCC_EE_EXECUTOR_TASK_REDUCE goes to the DEFAULT kernel and
 * receives the whole reduce descriptor plus the task flags; any other task
 * type goes to the STRIDED kernel with the fields of reduce_strided unpacked.
 */
#define LAUNCH_KERNEL_B(NAME, type, _AlphaType, _task, s, b, t)                \
    do {                                                                       \
        if (_task->task_type != UCC_EE_EXECUTOR_TASK_REDUCE) {                 \
            ucc_eee_task_reduce_strided_t *strided_args =                      \
                &_task->reduce_strided;                                        \
            UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME<type, _AlphaType>           \
                <<<b, t, 0, s>>>(                                              \
                    (type *)strided_args->src1, (type *)strided_args->src2,    \
                    (type *)strided_args->dst, strided_args->count,            \
                    strided_args->stride, strided_args->n_src2,                \
                    (bool)(_task->flags &                                      \
                           UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA),               \
                    strided_args->alpha);                                      \
        } else {                                                               \
            UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME<type, _AlphaType>           \
                <<<b, t, 0, s>>>(_task->reduce, _task->flags);                 \
        }                                                                      \
    } while (0)
228+
139229
#define LAUNCH_KERNEL(NAME, type, _task, s, b, t) \
140230
LAUNCH_KERNEL_A(NAME, type, type, _task, s, b, t)
141231

@@ -207,15 +297,15 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
207297
} \
208298
} while (0)
209299

210-
#define DT_REDUCE_FLOAT_COMPLEX(type, _alphaType, _task, _op, s, b, t) \
300+
#define DT_REDUCE_FLOAT_COMPLEX(NAME, type, _alphaType, _task, _op, s, b, t) \
211301
do { \
212302
switch (_op) { \
213303
case UCC_OP_AVG: \
214304
case UCC_OP_SUM: \
215-
LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \
305+
LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \
216306
break; \
217307
case UCC_OP_PROD: \
218-
LAUNCH_KERNEL_A(PROD, type, _alphaType, _task, s, b, t); \
308+
LAUNCH_KERNEL_B(NAME, type, _alphaType, _task, s, b, t); \
219309
break; \
220310
default: \
221311
ec_error(&ucc_ec_rocm.super, \
@@ -299,10 +389,10 @@ ucc_status_t ucc_ec_rocm_reduce(ucc_ee_executor_task_args_t *task,
299389
return UCC_ERR_NOT_SUPPORTED;
300390
#endif
301391
case UCC_DT_FLOAT32_COMPLEX:
302-
DT_REDUCE_FLOAT_COMPLEX(hipFloatComplex, float, task, op, stream, bk, th);
392+
DT_REDUCE_FLOAT_COMPLEX(PROD_FLOAT, hipFloatComplex, float, task, op, stream, bk, th);
303393
break;
304394
case UCC_DT_FLOAT64_COMPLEX:
305-
DT_REDUCE_FLOAT_COMPLEX(hipDoubleComplex, double, task, op, stream, bk, th);
395+
DT_REDUCE_FLOAT_COMPLEX(PROD_DOUBLE, hipDoubleComplex, double, task, op, stream, bk, th);
306396
break;
307397
case UCC_DT_BFLOAT16:
308398
ucc_assert(2 == sizeof(hip_bfloat16));

0 commit comments

Comments
 (0)