diff --git a/src/core/fft_engine.hh b/src/core/fft_engine.hh index 09791fc..8484618 100644 --- a/src/core/fft_engine.hh +++ b/src/core/fft_engine.hh @@ -1,151 +1,154 @@ /** * @file * @section LICENSE * * Copyright (©) 2016-2020 EPFL (École Polytechnique Fédérale de Lausanne), * Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides) * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published * by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . * */ /* -------------------------------------------------------------------------- */ #ifndef FFT_ENGINE_H #define FFT_ENGINE_H /* -------------------------------------------------------------------------- */ #include "fftw_interface.hh" #include "grid.hh" #include "grid_base.hh" #include "grid_hermitian.hh" #include #include #include /* -------------------------------------------------------------------------- */ namespace tamaas { class FFTEngine { protected: using plan_t = std::pair, fftw::plan>; using key_t = std::basic_string; using complex_t = fftw::helper::complex; public: /// Initialize with flags explicit FFTEngine(unsigned int flags = FFTW_ESTIMATE) : _flags(flags), plans() {} + virtual ~FFTEngine() = default; + /// Execute a forward plan on real and spectral template void forward(Grid& real, GridHermitian& spectral); /// Execute a backward plan on real and spectral template void backward(Grid& real, GridHermitian& spectral); /// Fill a grid with wavevector values in appropriate ordering template static Grid computeFrequencies(const std::array& sizes); /// Get FFTW plan flags unsigned int flags() const { return _flags; } + /// Cast to FFTW complex type + static auto cast(Complex* data) { return reinterpret_cast(data); } protected: /// Return the plans pair for a given transform signature plan_t& getPlans(key_t key); /// Make a transform signature from a pair of grids template static key_t make_key(Grid& real, GridHermitian& spectral); protected: unsigned int _flags; ///< FFTW flags std::map plans; ///< plans corresponding to signatures }; template FFTEngine::key_t FFTEngine::make_key(Grid& real, GridHermitian& spectral) { if (real.getNbComponents() != spectral.getNbComponents()) TAMAAS_EXCEPTION("Components do not match"); auto hermitian_dims = GridHermitian::hermitianDimensions(real.sizes()); if (not std::equal(hermitian_dims.begin(), hermitian_dims.end(), spectral.sizes().begin())) TAMAAS_EXCEPTION("Spectral grid does not have hermitian size"); // Reserve space for dimensions + components + both last strides key_t key(real.getDimension() + 3, 0); thrust::copy_n(real.sizes().begin(), dim, key.begin()); key[dim] = real.getNbComponents(); key[dim + 1] = real.getStrides().back(); key[dim + 2] = spectral.getStrides().back(); return key; } template void FFTEngine::forward(Grid& real, GridHermitian& spectral) { auto& plans = getPlans(make_key(real, spectral)); fftw::execute(plans.first, real.getInternalData(), - reinterpret_cast(spectral.getInternalData())); + cast(spectral.getInternalData())); } template void FFTEngine::backward(Grid& real, GridHermitian& spectral) { auto& plans = getPlans(make_key(real, spectral)); - fftw::execute(plans.second, - reinterpret_cast(spectral.getInternalData()), + fftw::execute(plans.second, cast(spectral.getInternalData()), real.getInternalData()); // Normalize real *= (1. / real.getNbPoints()); } template Grid FFTEngine::computeFrequencies(const std::array& sizes) { // If hermitian is true, we suppose the dimensions of freq are // reduced based on hermitian symetry and that it has dim components auto& n = sizes; Grid freq(n, dim); constexpr UInt dmax = dim - static_cast(hermitian); #pragma omp parallel for // to get rid of a compilation warning from nvcc for (UInt i = 0; i < freq.getNbPoints(); ++i) { UInt index = i; VectorProxy wavevector(freq(index * freq.getNbComponents())); std::array tuple{{0}}; /// Computing tuple from index for (Int d = dim - 1; d >= 0; d--) { tuple[d] = index % n[d]; index -= tuple[d]; index /= n[d]; } if (hermitian) wavevector(dim - 1) = tuple[dim - 1]; for (UInt d = 0; d < dmax; d++) { // Type conversion T td = tuple[d]; T nd = n[d]; T q = (tuple[d] < n[d] / 2) ? td : td - nd; wavevector(d) = q; } } return freq; } } // namespace tamaas #endif