diff --git a/hw3-heat-fft/test_fft.cc b/hw3-heat-fft/test_fft.cc index 760652d0..ff6663ae 100644 --- a/hw3-heat-fft/test_fft.cc +++ b/hw3-heat-fft/test_fft.cc @@ -1,39 +1,59 @@ #include "my_types.hh" #include "fft.hh" #include /*****************************************************************/ TEST(FFT, transform) { UInt N = 512; Matrix m(N); Real k = 2 * M_PI / N; for (auto&& entry : index(m)) { int i = std::get<0>(entry); int j = std::get<1>(entry); auto& val = std::get<2>(entry); val = cos(k * i); } Matrix res = FFT::transform(m); for (auto&& entry : index(res)) { int i = std::get<0>(entry); int j = std::get<1>(entry); auto& val = std::get<2>(entry); if (std::abs(val) > 1e-10) std::cout << i << "," << j << " = " << val << std::endl; if (i == 1 && j == 0) ASSERT_NEAR(std::abs(val), N * N / 2, 1e-10); else if (i == N - 1 && j == 0) ASSERT_NEAR(std::abs(val), N * N / 2, 1e-10); else ASSERT_NEAR(std::abs(val), 0, 1e-10); } } /*****************************************************************/ TEST(FFT, inverse_transform) { + + UInt N = 512; + Matrix m(N); + Real k = 2 * M_PI / N; + for (auto&& entry : index(m)) { + int i = std::get<0>(entry); + int j = std::get<1>(entry); + auto& val = std::get<2>(entry); + val = cos(k * i); + } + Matrix dft_m = FFT::transform(m); + Matrix invdft_dft_m = FFT::itransform(dft_m); + + for (auto&& entry : index(invdft_dft_m)) { + int i = std::get<0>(entry); + int j = std::get<1>(entry); + auto& val = std::get<2>(entry); + ASSERT_NEAR(val.real(), cos(k*i), 1e-10); + } + } /*****************************************************************/