Page Menu Home c4science

fftransform_cufft.cpp
No One Temporary

File Metadata

Created
Sun, May 5, 17:49

fftransform_cufft.cpp

/**
* @file
* LICENSE
*
* Copyright (©) 2016-2021 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#include "fftransform_cufft.hh"
#define MAX_THREADS 1024
namespace tamaas {
/**
 * Create the forward (D2Z) and backward (Z2D) cuFFT plans for a real grid and
 * its Hermitian spectrum.
 *
 * The components of a grid are stored next to each other in memory, so each
 * component is one batch of a `cufftPlanMany` plan: consecutive points of a
 * batch are `components` elements apart (stride) and consecutive batches are
 * 1 element apart (dist).
 *
 * @param real     real-space grid (input of the forward transform, output of
 *                 the backward transform)
 * @param spectral Hermitian spectral grid (output of the forward transform,
 *                 input of the backward transform)
 * Raises TAMAAS_EXCEPTION if either plan cannot be created.
 */
template <typename T, UInt dim>
FFTransformCUFFT<T, dim>::FFTransformCUFFT(Grid<T, dim>& real,
                                           GridHermitian<T, dim>& spectral)
    : FFTransform<T, dim>(real, spectral) {
  const int components = real.getNbComponents();

  // cuFFT advanced-layout parameters
  int rank = dim;  // dimension of the fft
  int n[dim] = {0};
  std::copy(real.sizes().begin(), real.sizes().end(),
            n);  // size of an individual fft

  // Storage shapes of the real and complex data. They must be passed
  // explicitly: with inembed/onembed == NULL cuFFT uses the basic contiguous
  // layout and ignores stride/dist, which is wrong for interleaved
  // multi-component grids.
  int real_embed[dim] = {0};
  int spectral_embed[dim] = {0};
  std::copy(n, n + dim, real_embed);
  std::copy(n, n + dim, spectral_embed);
  spectral_embed[dim - 1] = n[dim - 1] / 2 + 1;  // R2C output is Hermitian

  int howmany = components;  // one fft per component
  int stride = components;   // distance between consecutive points of a batch
  int dist = 1;              // components are next to each other in memory

  const cufftResult forward_res =
      cufftPlanMany(&forward_plan, rank, n, real_embed, stride, dist,
                    spectral_embed, stride, dist, CUFFT_D2Z, howmany);
  if (forward_res != CUFFT_SUCCESS)
    TAMAAS_EXCEPTION("Forward cufft plan creation failed");

  const cufftResult backward_res =
      cufftPlanMany(&backward_plan, rank, n, spectral_embed, stride, dist,
                    real_embed, stride, dist, CUFFT_Z2D, howmany);
  if (backward_res != CUFFT_SUCCESS) {
    // The destructor does not run when a constructor throws, so release the
    // already-created forward plan here to avoid leaking it.
    cufftDestroy(forward_plan);
    TAMAAS_EXCEPTION("Backward cufft plan creation failed");
  }
}
/* -------------------------------------------------------------------------- */
/// Release both cuFFT plans created by the constructor.
template <typename T, UInt dim>
FFTransformCUFFT<T, dim>::~FFTransformCUFFT() {
  // Return codes are deliberately ignored: a destructor must not throw.
  cufftDestroy(forward_plan);
  cufftDestroy(backward_plan);
}
/* -------------------------------------------------------------------------- */
/**
 * Real-to-complex (D2Z) transform of `real` into `spectral`.
 *
 * cuFFT executes asynchronously with respect to the host. The device is
 * synchronized before returning, presumably so that subsequent host-side
 * accesses see the result. Raises TAMAAS_EXCEPTION on failure.
 */
template <typename T, UInt dim>
void FFTransformCUFFT<T, dim>::forwardTransform() {
  cufftDoubleComplex* out =
      reinterpret_cast<cufftDoubleComplex*>(this->spectral.getInternalData());
  if (cufftExecD2Z(forward_plan, this->real.getInternalData(), out) !=
      CUFFT_SUCCESS)
    TAMAAS_EXCEPTION("Forward transform fail");
  // Wait for completion and surface any asynchronous execution error.
  if (cudaDeviceSynchronize() != cudaSuccess)
    TAMAAS_EXCEPTION("Forward transform synchronization fail");
}
/* -------------------------------------------------------------------------- */
/**
 * Complex-to-real (Z2D) transform of `spectral` into `real`, followed by
 * normalization.
 *
 * Per the cuFFT documentation, an out-of-place complex-to-real transform
 * overwrites its input, so `spectral` should be considered clobbered after
 * this call. Raises TAMAAS_EXCEPTION on failure.
 */
template <typename T, UInt dim>
void FFTransformCUFFT<T, dim>::backwardTransform() {
  cufftDoubleComplex* in =
      reinterpret_cast<cufftDoubleComplex*>(this->spectral.getInternalData());
  if (cufftExecZ2D(backward_plan, in, this->real.getInternalData()) !=
      CUFFT_SUCCESS)
    TAMAAS_EXCEPTION("Backward transform fail");
  // The transform must be complete (and error-free) before normalizing.
  if (cudaDeviceSynchronize() != cudaSuccess)
    TAMAAS_EXCEPTION("Backward transform synchronization fail");
  this->normalize();
}
/* -------------------------------------------------------------------------- */
template class FFTransformCUFFT<Real, 1>;
template class FFTransformCUFFT<Real, 2>;
template class FFTransformCUFFT<Real, 3>;
} // namespace tamaas

Event Timeline