Page MenuHomec4science

test_mpi.cpp
No OneTemporary

File Metadata

Created
Sat, Nov 23, 23:11

test_mpi.cpp

/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2023 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#include "fftw/mpi/fftw_mpi_engine.hh"
#include "grid.hh"
#include "grid_hermitian.hh"
#include "grid_view.hh"
#include "partitioner.hh"
#include "test.hh"
#include <fftw3-mpi.h>
using namespace tamaas;
using fft = fftw::helper<Real>;
/* -------------------------------------------------------------------------- */
/// Owning handle on a contiguous buffer of `size` elements starting at `ptr`.
///
/// The buffer is released through fftw::free() when the handle is destroyed,
/// so a given pointer must be owned by exactly one span. Copying is therefore
/// disabled (it would free the same buffer twice); moving transfers ownership
/// and leaves the source empty.
template <typename T>
struct span {
  T* ptr;
  std::size_t size;

  /// Empty span: owns nothing, destructor is a no-op.
  span() noexcept : ptr(nullptr), size(0) {}

  /// Takes ownership of `data`, which holds `count` elements.
  /// Keeps the former aggregate syntax `span<T>{data, count}` valid.
  span(T* data, std::size_t count) noexcept : ptr(data), size(count) {}

  span(const span&) = delete;
  span& operator=(const span&) = delete;

  /// Steals the buffer of `other`, which becomes empty.
  span(span&& other) noexcept : ptr(other.ptr), size(other.size) {
    other.ptr = nullptr;
    other.size = 0;
  }

  /// Releases the current buffer, then steals the one owned by `other`.
  span& operator=(span&& other) noexcept {
    if (this != &other) {
      release();
      ptr = other.ptr;
      size = other.size;
      other.ptr = nullptr;
      other.size = 0;
    }
    return *this;
  }

  ~span() { release(); }

  const T* begin() const { return ptr; }
  const T* end() const { return ptr + size; }
  T* begin() { return ptr; }
  T* end() { return ptr + size; }

  /// Implicit decay to the raw pointer, for C-style APIs.
  operator T*() { return ptr; }

private:
  /// Frees the owned buffer (if any) and resets to the empty state.
  void release() noexcept {
    if (ptr != nullptr)
      fftw::free(ptr);
    ptr = nullptr;
    size = 0;
  }
};
/* -------------------------------------------------------------------------- */
/// While a sequential_guard is alive, the MPI wrappers must report a
/// single-process world (rank 0, size 1, communicator MPI_COMM_SELF).
TEST(TestMPIInterface, SequentialGuard) {
  mpi::sequential_guard sequential_scope;

  // Non-fatal checks first so both are reported on failure
  EXPECT_EQ(mpi::rank(), 0);
  EXPECT_EQ(mpi::size(), 1);

  // The world communicator itself must have been swapped
  ASSERT_EQ(mpi::comm::world()._comm, MPI_COMM_SELF);
}
/* -------------------------------------------------------------------------- */
/// fftw::mpi::local_size_many must return the same allocation size, local
/// row count and local row start as the plain FFTW MPI 2D query.
TEST(TestFFTWInterface, MPISizes) {
  constexpr std::ptrdiff_t rows = 100;
  constexpr std::ptrdiff_t cols = 100;
  constexpr std::ptrdiff_t half_cols = cols / 2 + 1;

  // Reference values straight from the FFTW MPI C API
  std::ptrdiff_t ref_n0;
  std::ptrdiff_t ref_n0_start;
  auto ref_alloc = fftw_mpi_local_size_2d(rows, half_cols, mpi::comm::world(),
                                          &ref_n0, &ref_n0_start);

  // Same query through the wrapper (rank 2, one component)
  const std::ptrdiff_t dims[] = {rows, half_cols};
  auto wrapped = fftw::mpi::local_size_many(2, dims, 1);

  ASSERT_EQ(std::get<0>(wrapped), ref_alloc);
  ASSERT_EQ(std::get<1>(wrapped), ref_n0);
  ASSERT_EQ(std::get<2>(wrapped), ref_n0_start);
}
/* -------------------------------------------------------------------------- */
/// Partitioner<2> local sizes and offset must agree with what FFTW MPI
/// reports for the same global dimensions.
TEST(TestPartitioner, LocalSizes) {
  constexpr std::ptrdiff_t rows = 100;
  constexpr std::ptrdiff_t cols = 100;
  constexpr std::ptrdiff_t half_cols = cols / 2 + 1;

  // Reference decomposition from FFTW (allocation size is not needed here)
  std::ptrdiff_t ref_n0;
  std::ptrdiff_t ref_n0_start;
  fftw_mpi_local_size_2d(rows, half_cols, mpi::comm::world(), &ref_n0,
                         &ref_n0_start);

  auto partition = Partitioner<2>::local_size({rows, half_cols});

  // Only the first dimension is split; the second one stays whole
  ASSERT_EQ(partition[0], ref_n0);
  ASSERT_EQ(partition[1], half_cols);
  ASSERT_EQ(Partitioner<2>::local_offset({rows, half_cols}), ref_n0_start);
}
/* -------------------------------------------------------------------------- */
/// The offset computed from a local grid must equal the distance, in the
/// global grid's storage, between the first global row and the first row
/// assigned to this rank.
TEST(TestPartitioner, LocalOffsetInGrid) {
  std::array<UInt, 2> local_shape = {10, 10};

  // Global first dimension comes from an allreduce of the local one
  // (presumably a sum over ranks -- confirm against mpi::allreduce)
  std::array<UInt, 2> global_shape = {mpi::allreduce(local_shape[0]),
                                      local_shape[1]};

  Grid<UInt, 2> local(local_shape, 1);
  Grid<UInt, 2> global(global_shape, 1);

  auto offset = Partitioner<2>::local_offset(local);
  auto first_row = Partitioner<2>::local_offset(global_shape);

  // Pointer distance between this rank's first row and the global origin
  auto expected = &global(first_row, 0) - &global(0, 0);
  ASSERT_EQ(offset, expected);
}
/* -------------------------------------------------------------------------- */
/// Scattering a grid held by rank 0 and gathering it back must reproduce
/// the original data on rank 0.
TEST(TestPartitioner, Scatter) {
  const bool is_root = (mpi::rank() == 0);

  std::array<UInt, 2> shape = {10, 10};
  Grid<Real, 2> global(shape, 3);

  // Only the root holds the global data; other ranks get an empty grid
  if (!is_root)
    global.resize({0, 0});
  std::iota(global.begin(), global.end(), 0);

  // Both calls are made by every rank before any rank bails out
  auto local = Partitioner<2>::scatter(global);
  auto gathered = Partitioner<2>::gather(local);

  // The round trip is only checked on the root
  if (is_root) {
    ASSERT_TRUE(compare(gathered, global, AreFloatEqual()));
  }
}
/* -------------------------------------------------------------------------- */
/// Forward transform computed by the MPI engine on a distributed grid must
/// match the one computed by the serial engine on the full grid.
TEST(TestFFTWMPIEngine, BasicTransform) {
  constexpr std::ptrdiff_t rows = 60;
  constexpr std::ptrdiff_t cols = 60;

  // Per-rank real and spectral storage
  auto spectral_shape = Partitioner<2>::local_size({rows, cols / 2 + 1});
  Grid<Real, 2> real({spectral_shape[0], cols}, 1);
  GridHermitian<Real, 2> spectral(spectral_shape, 1);

  // Start the ramp at this rank's offset so that the local pieces line up
  // with the 0, 1, 2, ... ramp used for the serial reference below
  auto first_value = Partitioner<2>::local_offset(real);
  std::iota(real.begin(), real.end(), first_value);

  FFTWMPIEngine().forward(real, spectral);
  auto gathered = Partitioner<2>::gather(spectral);

  // Comparison against the serial result happens on the root only
  if (mpi::rank() != 0)
    return;

  Grid<Real, 2> real_global({rows, cols}, 1);
  GridHermitian<Real, 2> solution({rows, cols / 2 + 1}, 1);
  std::iota(real_global.begin(), real_global.end(), 0);
  FFTWEngine().forward(real_global, solution);

  // Debugging aid, kept disabled:
  // solution -= gathered;
  // Logger().get(LogLevel::info) << solution << '\n';

  // The tolerance is loosened because iota is a poor input: the serial and
  // parallel transforms do not perform the same operations, so entries that
  // should be close to zero end up with different values in the two cases.
  ASSERT_TRUE(compare(gathered, solution, AreComplexEqual{1e-11}));
}
/* -------------------------------------------------------------------------- */
/// A forward transform followed by a backward transform with the MPI engine
/// must give back the original (uniform) field.
TEST(TestFFTWMPIEngine, BackwardsTransform) {
  constexpr std::ptrdiff_t rows = 128;
  constexpr std::ptrdiff_t cols = 128;

  // Per-rank real and spectral storage
  auto spectral_shape = Partitioner<2>::local_size({rows, cols / 2 + 1});
  Grid<Real, 2> real({spectral_shape[0], cols}, 1);
  GridHermitian<Real, 2> spectral(spectral_shape, 1);

  FFTWMPIEngine engine;

  // Transform a field of ones ...
  real = 1.;
  engine.forward(real, spectral);

  // ... wipe the real data so the check cannot pass on stale values ...
  real = 0;

  // ... and rebuild it from the spectral data
  engine.backward(real, spectral);

  auto gathered = Partitioner<2>::gather(real);

  // The gathered field is only checked on the root
  if (mpi::rank() != 0)
    return;

  Grid<Real, 2> reference({rows, cols}, 1);
  reference = 1.;
  ASSERT_TRUE(compare(gathered, reference, AreFloatEqual()));
}
/* -------------------------------------------------------------------------- */
/// Forward transform of a two-component field: the MPI engine result must
/// match the serial engine result.
TEST(TestFFTWMPIEngine, ComponentsTransform) {
  constexpr std::ptrdiff_t rows = 10;
  constexpr std::ptrdiff_t cols = 10;

  // Per-rank storage, two components per point
  auto spectral_shape = Partitioner<2>::local_size({rows, cols / 2 + 1});
  Grid<Real, 2> real({spectral_shape[0], cols}, 2);
  GridHermitian<Real, 2> spectral(spectral_shape, 2);

  real = 1;
  FFTWMPIEngine().forward(real, spectral);
  auto gathered = Partitioner<2>::gather(spectral);

  // Serial reference is built and compared on the root only
  if (mpi::rank() != 0)
    return;

  Grid<Real, 2> real_global({rows, cols}, 2);
  GridHermitian<Real, 2> solution({rows, cols / 2 + 1}, 2);
  real_global = 1;
  FFTWEngine().forward(real_global, solution);

  ASSERT_TRUE(compare(gathered, solution, AreComplexEqual()));
}

Event Timeline