#include "ec_rocm.h"
#include "utils/ucc_math_op.h"
#include <inttypes.h>
+#include <hip/hip_complex.h>

#define ROCM_REDUCE_WITH_OP_DEFAULT(NAME, _OP) \
    template <typename _Type, typename _AlphaType> \
...
        } \
    }

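+/* Element-wise complex product: folds task.n_srcs source buffers into
+ * task.dst with _OP (a HIP complex-multiply helper), then optionally
+ * scales the result by task.alpha. */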
+#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(NAME, _OP) \
+    template <typename _Type, typename _AlphaType> \
+    __global__ void UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME(ucc_eee_task_reduce_t task, \
+                                                            uint16_t flags) \
+    { \
+        size_t start = blockIdx.x * blockDim.x + threadIdx.x; \
+        size_t step = blockDim.x * gridDim.x; \
+        size_t count = task.count; \
+        int n_srcs = task.n_srcs; \
+        const _Type **s = (const _Type **)task.srcs; \
+        _Type *d = (_Type *)task.dst; \
+        size_t i; \
+ \
+        switch (n_srcs) { \
+        case 2: \
+            for (i = start; i < count; i += step) { \
+                d[i] = _OP(s[0][i], s[1][i]); \
+            } \
+            break; \
+        default: \
+            for (i = start; i < count; i += step) { \
+                d[i] = _OP(s[0][i], s[1][i]); \
+                for (size_t j = 2; j < n_srcs; j++) { \
+                    d[i] = _OP(d[i], s[j][i]); \
+                } \
+            } \
+            break; \
+        } \
+        if (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) { \
+            for (i = start; i < count; i += step) { \
+                d[i] = d[i] * (_AlphaType)task.alpha; \
+            } \
+        } \
+    }
+
#define ROCM_REDUCE_WITH_OP_STRIDED(NAME, _OP) \
    template <typename _Type, typename _AlphaType> \
    __global__ void UCC_REDUCE_ROCM_STRIDED_##NAME( \
...
        } \
    }

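+/* Strided variant: src2 holds n_src2 equally sized blocks laid out
+ * `stride` bytes apart; they are folded together with src1 via _OP into
+ * dst, with optional scaling by alpha. */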
+#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(NAME, _OP) \
+    template <typename _Type, typename _AlphaType> \
+    __global__ void UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME( \
+        const _Type *s1, const _Type *s2, _Type *d, size_t count, \
+        size_t stride, uint16_t n_src2, const bool with_alpha, \
+        const double alpha) \
+    { \
+        size_t start = blockIdx.x * blockDim.x + threadIdx.x; \
+        size_t step = blockDim.x * gridDim.x; \
+        size_t ld = stride / sizeof(_Type); \
+        size_t i; \
+ \
+        ucc_assert_system(stride % sizeof(_Type) == 0); \
+        switch (n_src2) { \
+        case 1: \
+            for (i = start; i < count; i += step) { \
+                d[i] = _OP(s1[i], s2[i]); \
+            } \
+            break; \
+        default: \
+            for (i = start; i < count; i += step) { \
+                d[i] = _OP(s1[i], s2[i]); \
+                for (size_t j = 1; j < n_src2; j++) { \
+                    d[i] = _OP(d[i], s2[i + j * ld]); \
+                } \
+            } \
+            break; \
+        } \
+        if (with_alpha) { \
+            for (i = start; i < count; i += step) { \
+                d[i] = d[i] * (_AlphaType)alpha; \
+            } \
+        } \
+    }
+
ROCM_REDUCE_WITH_OP_DEFAULT(SUM, DO_OP_SUM);
ROCM_REDUCE_WITH_OP_DEFAULT(PROD, DO_OP_PROD);
+ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_DOUBLE, hipCmul);
+ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_FLOAT, hipCmulf);
ROCM_REDUCE_WITH_OP_DEFAULT(MIN, DO_OP_MIN);
ROCM_REDUCE_WITH_OP_DEFAULT(MAX, DO_OP_MAX);
ROCM_REDUCE_WITH_OP_DEFAULT(LAND, DO_OP_LAND);
@@ -112,6 +185,8 @@ ROCM_REDUCE_WITH_OP_DEFAULT(BXOR, DO_OP_BXOR);

ROCM_REDUCE_WITH_OP_STRIDED(SUM, DO_OP_SUM);
ROCM_REDUCE_WITH_OP_STRIDED(PROD, DO_OP_PROD);
+ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_DOUBLE, hipCmul);
+ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_FLOAT, hipCmulf);
ROCM_REDUCE_WITH_OP_STRIDED(MIN, DO_OP_MIN);
ROCM_REDUCE_WITH_OP_STRIDED(MAX, DO_OP_MAX);
ROCM_REDUCE_WITH_OP_STRIDED(LAND, DO_OP_LAND);
@@ -136,6 +211,21 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
        } \
    } while (0)

+#define LAUNCH_KERNEL_B(NAME, type, _AlphaType, _task, s, b, t) \
+    do { \
+        if (_task->task_type == UCC_EE_EXECUTOR_TASK_REDUCE) { \
+            UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME<type, _AlphaType> \
+                <<<b, t, 0, s>>>(_task->reduce, _task->flags); \
+        } else { \
+            ucc_eee_task_reduce_strided_t *trs = &_task->reduce_strided; \
+            UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME<type, _AlphaType><<<b, t, 0, s>>>( \
+                (type *)trs->src1, (type *)trs->src2, (type *)trs->dst, \
+                trs->count, trs->stride, trs->n_src2, \
+                (bool)(_task->flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA), \
+                trs->alpha); \
+        } \
+    } while (0)
+
#define LAUNCH_KERNEL(NAME, type, _task, s, b, t) \
    LAUNCH_KERNEL_A(NAME, type, type, _task, s, b, t)

@@ -207,15 +297,15 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
        } \
    } while (0)

-#define DT_REDUCE_FLOAT_COMPLEX(type, _alphaType, _task, _op, s, b, t) \
+#define DT_REDUCE_FLOAT_COMPLEX(NAME, type, _alphaType, _task, _op, s, b, t) \
    do { \
        switch (_op) { \
        case UCC_OP_AVG: \
        case UCC_OP_SUM: \
-            LAUNCH_KERNEL_A(SUM, type, _alphaType, _task, s, b, t); \
+            LAUNCH_KERNEL_A(SUM, type, _alphaType, _task, s, b, t); \
            break; \
        case UCC_OP_PROD: \
-            LAUNCH_KERNEL_A(PROD, type, _alphaType, _task, s, b, t); \
+            LAUNCH_KERNEL_B(NAME, type, _alphaType, _task, s, b, t); \
            break; \
        default: \
            ec_error(&ucc_ec_rocm.super, \
@@ -299,10 +389,10 @@ ucc_status_t ucc_ec_rocm_reduce(ucc_ee_executor_task_args_t *task,
        return UCC_ERR_NOT_SUPPORTED;
#endif
    case UCC_DT_FLOAT32_COMPLEX:
-        DT_REDUCE_FLOAT_COMPLEX(hipFloatComplex, float, task, op, stream, bk, th);
+        DT_REDUCE_FLOAT_COMPLEX(PROD_FLOAT, hipFloatComplex, float, task, op, stream, bk, th);
        break;
    case UCC_DT_FLOAT64_COMPLEX:
-        DT_REDUCE_FLOAT_COMPLEX(hipDoubleComplex, double, task, op, stream, bk, th);
+        DT_REDUCE_FLOAT_COMPLEX(PROD_DOUBLE, hipDoubleComplex, double, task, op, stream, bk, th);
        break;
    case UCC_DT_BFLOAT16:
        ucc_assert(2 == sizeof(hip_bfloat16));
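
The two complex instantiations map UCC_OP_PROD onto HIP's complex-multiply helpers from <hip/hip_complex.h>: hipCmulf for hipFloatComplex (UCC_DT_FLOAT32_COMPLEX) and hipCmul for hipDoubleComplex (UCC_DT_FLOAT64_COMPLEX). The host-only sketch below is illustrative and not part of the patch; it spells out the per-element fold that UCC_REDUCE_ROCM_DEFAULT_COMPLEX_PROD_FLOAT performs for three source buffers, assuming it is built with hipcc (variable names are invented for the example).

    // Illustrative sketch, not part of the patch: the per-element fold the
    // default complex-product kernel performs, written out on the host.
    #include <hip/hip_complex.h>
    #include <cstdio>

    int main()
    {
        hipFloatComplex src0 = make_hipFloatComplex(1.0f, 2.0f);
        hipFloatComplex src1 = make_hipFloatComplex(3.0f, -1.0f);
        hipFloatComplex src2 = make_hipFloatComplex(0.5f, 0.5f);

        // n_srcs > 2 path: d = _OP(s[0], s[1]), then fold in the remaining sources.
        hipFloatComplex d = hipCmulf(src0, src1);
        d = hipCmulf(d, src2);

        // With UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA set, the kernel would then
        // scale this result by task.alpha.
        printf("prod = (%f, %f)\n", hipCrealf(d), hipCimagf(d));
        return 0;
    }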