diff --git a/tests/test_fftfreq.cpp b/tests/test_fftfreq.cpp index 812f959..5c3046c 100644 --- a/tests/test_fftfreq.cpp +++ b/tests/test_fftfreq.cpp @@ -1,118 +1,120 @@ /* * SPDX-License-Indentifier: AGPL-3.0-or-later * * Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne), * Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides) * Copyright (©) 2020-2023 Lucas Frérot * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published * by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . * */ /* -------------------------------------------------------------------------- */ #include "fft_engine.hh" #include "grid.hh" #include "grid_hermitian.hh" #include "grid_view.hh" #include "mpi_interface.hh" #include "partitioner.hh" #include "test.hh" #include #include #include using namespace tamaas; namespace py = pybind11; /* -------------------------------------------------------------------------- */ TEST(TestFFTEngine, Frequencies1D) { mpi::sequential_guard guard{}; std::array sizes{10}; auto freq{FFTEngine::computeFrequencies(sizes)}; py::module fftfreq{py::module::import("fftfreq")}; std::vector reference(freq.dataSize()); py::array py_arg{static_cast(reference.size()), reference.data(), py::none()}; fftfreq.attr("frequencies1D")(py_arg); EXPECT_TRUE(compare(reference, freq, AreFloatEqual())) << "Non hermitian frequencies are wrong"; auto hfreq{FFTEngine::computeFrequencies(sizes)}; std::iota(reference.begin(), reference.end(), 0); EXPECT_TRUE(compare(reference, hfreq, AreFloatEqual())) << "Hermitian frequencies are wrong"; } /* -------------------------------------------------------------------------- */ TEST(TestFFTEngine, Frequencies2D) { std::array sizes{10, 10}; auto global_sizes = Partitioner<2>::global_size(sizes); auto freq{FFTEngine::computeFrequencies(sizes)}; auto hfreq{FFTEngine::computeFrequencies(sizes)}; auto gathered{Partitioner<2>::gather(freq)}; auto hgathered{Partitioner<2>::gather(hfreq)}; if (mpi::rank() != 0) return; py::module fftfreq{py::module::import("fftfreq")}; std::vector reference(gathered.dataSize()); - py::array py_arg{ - {global_sizes[0], global_sizes[1], 2u}, reference.data(), py::none()}; + py::array py_arg{{static_cast(global_sizes[0]), + static_cast(global_sizes[1]), py::ssize_t{2}}, + reference.data(), + py::none()}; fftfreq.attr("frequencies2D")(py_arg); EXPECT_TRUE(compare(reference, gathered, AreFloatEqual())) << "Non hermitian frequencies are wrong"; fftfreq.attr("hfrequencies2D")(py_arg); EXPECT_TRUE(compare(reference, hgathered, AreFloatEqual())) << "Hermitian frequencies are wrong"; } /* -------------------------------------------------------------------------- */ TEST(TestFFTEngine, RealComponents) { mpi::sequential_guard guard; auto even_even{FFTEngine::realCoefficients<2>({{10, 10}})}; auto even_odd{FFTEngine::realCoefficients<2>({{10, 11}})}; auto odd_even{FFTEngine::realCoefficients<2>({{11, 10}})}; auto odd_odd{FFTEngine::realCoefficients<2>({{11, 11}})}; std::vector even_even_ref{0, 0, 5, 0, 0, 5, 5, 5}; std::vector even_odd_ref{0, 0, 5, 0}; std::vector odd_even_ref{0, 0, 0, 5}; std::vector odd_odd_ref{0, 0}; auto flatten = [](auto v) { std::vector flat; for (auto&& tuple : v) for (auto i : tuple) flat.push_back(i); return flat; }; compare(flatten(even_even), even_even_ref); compare(flatten(even_odd), even_odd_ref); compare(flatten(odd_even), odd_even_ref); compare(flatten(odd_odd), odd_odd_ref); }