I have been looking into the #define directive recently, and I am confused by how it is used in the following code example. Could anyone explain how it works?
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float* val_list) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b)                                 \
  val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), a, b); \
  val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), a, b); \
  *(val_list + 0) += val0_tmp;                                   \
  *(val_list + 1) += val1_tmp

  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}
From my understanding, when WarpReduceSumOneStep(16, 32); occurs, the compiler substitutes it with the block between #define and #undef, right?
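To make the question concrete, here is a minimal host-side sketch of the same pattern in plain C++. It rests on standard preprocessor behavior: a macro body consists only of the #define line and the lines joined to it by a trailing backslash, and #undef simply removes the name afterwards. So the invocations sitting between #define and #undef are not part of what gets substituted. The names xor_exchange, reduce_sum_demo, ReduceSumOneStep and kLanes are hypothetical, and xor_exchange is only a CPU stand-in for the butterfly exchange, not the real __shfl_xor_sync.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the 32 threads of a warp.
constexpr std::size_t kLanes = 32;

// CPU analogy of a butterfly exchange: "lane" i reads the value held by
// lane (i ^ lane_mask). This is an assumption for illustration only and
// does not model masks, widths, or synchronization.
std::array<float, kLanes> xor_exchange(const std::array<float, kLanes>& v,
                                       std::size_t lane_mask) {
    assert(lane_mask < kLanes);
    std::array<float, kLanes> out{};
    for (std::size_t i = 0; i < kLanes; ++i) {
        out[i] = v[i ^ lane_mask];
    }
    return out;
}

float reduce_sum_demo(std::array<float, kLanes> vals) {
    std::array<float, kLanes> tmp{};

    // The macro body is just the two statements below. Each trailing
    // backslash continues the definition onto the next line; the first
    // line without one ends the definition.
#define ReduceSumOneStep(a)                                   \
    tmp = xor_exchange(vals, a);                              \
    for (std::size_t i = 0; i < kLanes; ++i) vals[i] += tmp[i]

    // Each invocation is replaced by the body with `a` substituted, e.g.
    //   tmp = xor_exchange(vals, 16);
    //   for (...) vals[i] += tmp[i];
    ReduceSumOneStep(16);
    ReduceSumOneStep(8);
    ReduceSumOneStep(4);
    ReduceSumOneStep(2);
    ReduceSumOneStep(1);

    // From here on the name is no longer defined as a macro.
#undef ReduceSumOneStep

    // After the five steps every lane holds the sum of all lanes.
    return vals[0];
}
```

Calling reduce_sum_demo on an array filled with 1.0f should give 32.0f, since each of the five steps doubles the partial sum held in every lane. Placing the #define and #undef inside the function is a scoping habit: it keeps the helper name from leaking into the rest of the translation unit.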