Page MenuHomec4science

No OneTemporary

File Metadata

Created
Sat, Nov 23, 23:11
diff --git a/hw3-heat-fft/test_fft.cc b/hw3-heat-fft/test_fft.cc
index cab18dd6..66e6ec87 100644
--- a/hw3-heat-fft/test_fft.cc
+++ b/hw3-heat-fft/test_fft.cc
@@ -1,110 +1,120 @@
#include "my_types.hh"
#include "fft.hh"
#include <gtest/gtest.h>
#include <cmath>
#include <complex>
#include <fstream>
#include <tuple>
/*****************************************************************/
/// Forward transform of a single cosine mode.
///
/// The square input of size N holds cos(2*pi*i/N), which depends only on the
/// first index. Its transform must vanish everywhere except at (1,0) and
/// (N-1,0), where the magnitude equals N*N/2.
TEST(FFT, transform) {
  const UInt N = 512;
  const Real tol = 1e-10;
  Matrix<complex> m(N);
  const Real k = 2 * M_PI / N;
  for (auto&& entry : index(m)) {
    const UInt i = std::get<0>(entry);
    auto& val = std::get<2>(entry);
    val = cos(k * i);
  }

  Matrix<complex> res = FFT::transform(m);

  const Real peak = 0.5 * N * N;
  for (auto&& entry : index(res)) {
    // Unsigned indices so that the comparison with N - 1 below does not mix
    // signed and unsigned operands.
    const UInt i = std::get<0>(entry);
    const UInt j = std::get<1>(entry);
    auto& val = std::get<2>(entry);
    const bool is_peak = (j == 0) && (i == 1 || i == N - 1);
    const Real expected = is_peak ? peak : 0.;
    ASSERT_NEAR(std::abs(val), expected, tol)
        << "at (i,j) = (" << i << "," << j << ")";
  }
}
/*****************************************************************/
/// Round trip: the inverse transform of the forward transform must give back
/// the original input.
///
/// The input is the same real cosine field as in the `transform` test. After
/// the round trip, the real part must match cos(2*pi*i/N) and the imaginary
/// part must vanish.
TEST(FFT, inverse_transform) {
  const UInt N = 512;
  const Real tol = 1e-10;
  Matrix<complex> m(N);
  const Real k = 2 * M_PI / N;
  for (auto&& entry : index(m)) {
    const UInt i = std::get<0>(entry);
    auto& val = std::get<2>(entry);
    val = cos(k * i);
  }

  // Forward transform followed by the inverse transform.
  Matrix<complex> dft_m = FFT::transform(m);
  Matrix<complex> invdft_dft_m = FFT::itransform(dft_m);

  for (auto&& entry : index(invdft_dft_m)) {
    const UInt i = std::get<0>(entry);
    const UInt j = std::get<1>(entry);
    auto& val = std::get<2>(entry);
    ASSERT_NEAR(val.real(), cos(k * i), tol)
        << "real part at (i,j) = (" << i << "," << j << ")";
    // The input is purely real, so any imaginary residue is an error that
    // the real-part check alone would not catch.
    ASSERT_NEAR(val.imag(), 0., tol)
        << "imaginary part at (i,j) = (" << i << "," << j << ")";
  }
}
/*****************************************************************/
TEST(FFT, compute_frequencies){
- double tmp;
- double val;
+ double tmp2;
+ UInt tmp;
+ UInt val;
UInt N;
// These tests are based on values produced by numpy.fft.fftfreq
// which are saved in txt files
// See the python script fft_generate_test_values.py
/////////////// EVEN TEST ///////////////////////
N = 20;
Matrix<std::complex<int>> even_freqs = FFT::computeFrequencies(N);
std::ifstream in_even("fftfreq_test_values_even.txt");
if (!in_even) {
std::cout << "Cannot open even file.\n";
return;
}
for (int j = 0; j < N; j++) {
for (int i = 0; i < N; i++) {
// // DEBUG:
// std::cout << "(i,j) = (" << i << "," << j << ")" << std::endl;
- in_even >> tmp;
+ in_even >> tmp2;
+ tmp = int(tmp2); // because the output is stored in float format
val = even_freqs(i,j).real();
- ASSERT_NEAR(val, tmp, 1e-10);
+ EXPECT_EQ(val, tmp);
}
}
in_even.close();
/////////////// ODD TEST ///////////////////////
N = 19;
Matrix<std::complex<int>> odd_freqs = FFT::computeFrequencies(N);
std::ifstream in_odd("fftfreq_test_values_odd.txt");
if (!in_odd) {
std::cout << "Cannot open odd file.\n";
return;
}
for (int j = 0; j < N; j++) {
for (int i = 0; i < N; i++) {
// // DEBUG:
// std::cout << "(i,j) = (" << i << "," << j << ")" << std::endl;
- in_odd >> tmp;
+ in_odd >> tmp2;
+ tmp = int(tmp2); // because the output is stored in float format
val = odd_freqs(i,j).real();
- ASSERT_NEAR(val, tmp, 1e-10);
+ EXPECT_EQ(val, tmp);
}
}
in_odd.close();
}

Event Timeline