diff --git a/tests/test_thrust.cpp b/tests/test_thrust.cpp index f9a4e57..2c5af7a 100644 --- a/tests/test_thrust.cpp +++ b/tests/test_thrust.cpp @@ -1,21 +1,41 @@ #include "loop.hh" #include "grid.hh" #include #include #include using namespace tamaas; +template +class ApplyFunctor { +public: + ApplyFunctor(const Functor & functor):functor(functor) {} + + template + __host__ __device__ + void operator()(Tuple&& t) { + thrust::system::cuda::detail::bulk_::detail::apply_from_tuple(functor, t); + } + +private: + const Functor & functor; +}; + +template +void loop(Functor&& func, Containers&&... containers) { + auto begin = thrust::make_zip_iterator(thrust::make_tuple(containers.begin()...)); + auto end = thrust::make_zip_iterator(thrust::make_tuple(containers.end()...)); + + thrust::for_each(begin, end, ApplyFunctor(func)); +} + int main() { Grid grid({20}, 2); Grid other({20}, 2); std::iota(grid.begin(), grid.end(), 0); other = 2; - auto begin = thrust::make_zip_iterator(thrust::make_tuple(grid.begin(), other.begin())); - auto end = thrust::make_zip_iterator(thrust::make_tuple(grid.end(), other.end())); - - thrust::for_each(begin, end, [] CUDA_LAMBDA (thrust::tuple t) { thrust::get<0>(t) += thrust::get<1>(t); }); + loop([] CUDA_LAMBDA (Real & x, Real & y) { x += y; }, grid, other); return 0; }