diff --git a/hw3-heat-fft/fft.hh b/hw3-heat-fft/fft.hh index 0319abe1..a1b74179 100644 --- a/hw3-heat-fft/fft.hh +++ b/hw3-heat-fft/fft.hh @@ -1,37 +1,86 @@ #ifndef FFT_HH #define FFT_HH /* ------------------------------------------------------ */ #include "matrix.hh" #include "my_types.hh" #include /* ------------------------------------------------------ */ struct FFT { static Matrix transform(Matrix& m); static Matrix itransform(Matrix& m); static Matrix> computeFrequencies(int size); }; /* ------------------------------------------------------ */ inline Matrix FFT::transform(Matrix& m_in) { + int N = m_in.rows(); + + std::cout << "N = " << N << std::endl; + + // Declare/initialize in and out for FFTW + std::vector in(N*N); + std::vector out(N*N); + + fftw_plan plan = fftw_plan_dft_2d( N, + N, + reinterpret_cast(&in[0]), + reinterpret_cast(&out[0]), + FFTW_FORWARD, + FFTW_ESTIMATE ); + + // Fill "in" with input-matrix values + for (auto&& entry : index(m_in)) { + + // Get i and j index of matrix entry + int i = std::get<0>(entry); + int j = std::get<1>(entry); + + // Value of matrix entry at (i,j) + complex val = std::get<2>(entry); + + // Create 1D index k from 2D indices (i,j) + int k = i + j * N; + + // Insert value into in + in[k] = val; + } + + // Execute FFTW + fftw_execute(plan); + + // Fill output-matrix with "out"-values + Matrix m_out(N); + for (auto&& entry : index(m_out)) { + int i = std::get<0>(entry); + int j = std::get<1>(entry); + complex& v = std::get<2>(entry); + int k = i + j * N; + v = out[k]; + } + + // Destroy FFTW + fftw_destroy_plan(plan); + + return m_out; } /* ------------------------------------------------------ */ inline Matrix FFT::itransform(Matrix& m_in) { } /* ------------------------------------------------------ */ /* ------------------------------------------------------ */ inline Matrix> FFT::computeFrequencies(int size) { } #endif // FFT_HH