I ran into a problem when using templates in kernel wrapper functions.
Here is the code as I originally wrote it:
//----------------------------------------
// cuda_demo.cuh
// Declaration only: the header exposes the host-side wrapper's signature,
// while its definition is placed in cuda_demo.cu.
template<typename T>
void kernel_wrapper(T param);
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_demo.cuh"
// Placeholder device kernel, templated on the type T of its single
// by-value argument. The body is left empty in this minimal example.
template <typename T>
__global__ void my_kernel(T param)
{
    // do something
}
// Host-side entry point: launches my_kernel<T> with one block containing one
// thread and passes `param` through by value.
// This template is defined only in the .cu file, so a translation unit that
// includes just the declaration cannot instantiate it on its own.
template <typename T>
void kernel_wrapper(T param)
{
    my_kernel<T><<<1, 1>>>(param);
}
//----------------------------------------
// main.cpp
#include "cuda_demo.cuh"
// Minimal caller: invokes the templated wrapper with an int argument, so T is
// deduced as int at the call site.
int main()
{
    const int param = 10;
    kernel_wrapper(param);
    return 0;
}
I soon found out that templates should be implemented in the header file (see "Why can templates only be implemented in the header file?").
That question offers two solutions. The common one is "to write the template declaration in a header file, then implement the class in an implementation file (for example .tpp), and include this implementation file at the end of the header".
So I changed the code:
//----------------------------------------
// cuda_demo.cuh
// Declaration of the host-side wrapper; in this variant the definition is
// pulled in by the #include of cuda_demo.cu that follows.
template<typename T>
void kernel_wrapper(T param);
#include "cuda_demo.cu"
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
// Placeholder device kernel parameterised on the argument type T.
// Nothing is computed yet; the body only marks where the work would go.
template <typename T>
__global__ void my_kernel(T param)
{
    // do something
}
// Host-side wrapper: launches my_kernel with one block of one thread,
// passing `param` by value.
// NOTE(review): the <<<...>>> launch syntax is a CUDA language extension.
// Because cuda_demo.cuh includes this file, every source that includes the
// header (e.g. main.cpp) also has to parse this launch. Presumably main.cpp
// is built with a host-only C++ compiler, which would explain the error
// reported below; confirm against the actual build commands.
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
The compiler gives me the following error:
error: expected primary-expression before '<' token
my_kernel<<<1,1>>>(param);
The same error occurs when I put all of the CUDA code in "cuda_demo.cuh".
Then I tried the second solution, as follows:
//----------------------------------------
// cuda_demo.cuh
// Declaration only; the definition and an explicit instantiation for int
// are provided in cuda_demo.cu.
template<typename T>
void kernel_wrapper(T param);
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_demo.cuh"
// Placeholder device kernel; T is the type of the value handed to each
// thread. The body is a stub in this example.
template <typename T>
__global__ void my_kernel(T param)
{
    // do something
}
// Host-side wrapper that launches my_kernel<T> on a grid of one block with
// one thread, forwarding `param` by value.
// The definition lives only in this .cu file, so other translation units can
// use just the specializations that are explicitly instantiated here.
// NOTE: the launch is not followed by any error check.
template <typename T>
void kernel_wrapper(T param)
{
    my_kernel<T><<<1, 1>>>(param);
}
// Explicit instantiation: emits kernel_wrapper<int> in this translation unit
// so that callers which see only the declaration can link against it.
template void kernel_wrapper<int>(int param);
This one works well! But in my project, 'T' is not a simple type; it may be a recursively nested type like
Class_1<Class_2<Class_3<...>>>,
which means I cannot enumerate the concrete types of 'T' in advance.
Does anybody know how to solve this?
Thanks.