diff --git a/hw3-heat-fft/fft.hh b/hw3-heat-fft/fft.hh index 7f317bac..4bf14268 100644 --- a/hw3-heat-fft/fft.hh +++ b/hw3-heat-fft/fft.hh @@ -1,138 +1,184 @@ #ifndef FFT_HH #define FFT_HH /* ------------------------------------------------------ */ #include "matrix.hh" #include "my_types.hh" #include /* ------------------------------------------------------ */ struct FFT { static Matrix transform(Matrix& m); static Matrix itransform(Matrix& m); static Matrix> computeFrequencies(int size); }; /* ------------------------------------------------------ */ inline Matrix FFT::transform(Matrix& m_in) { int N = m_in.rows(); //std::cout << "DEBUG: N = " << N << std::endl; // Declare/initialize in and out for FFTW std::vector in(N*N); std::vector out(N*N); fftw_plan plan = fftw_plan_dft_2d( N, N, reinterpret_cast(&in[0]), reinterpret_cast(&out[0]), FFTW_FORWARD, FFTW_ESTIMATE ); // Fill "in" with input-matrix values for (auto&& entry : index(m_in)) { // Get i and j index of matrix entry int i = std::get<0>(entry); int j = std::get<1>(entry); // Value of matrix entry at (i,j) complex val = std::get<2>(entry); // Create 1D index k from 2D indices (i,j) int k = i + j * N; // Insert value into in in[k] = val; } // Execute FFTW fftw_execute(plan); // Fill output-matrix with "out"-values Matrix m_out(N); for (auto&& entry : index(m_out)) { int i = std::get<0>(entry); int j = std::get<1>(entry); complex& v = std::get<2>(entry); int k = i + j * N; v = out[k]; } // Destroy FFTW fftw_destroy_plan(plan); return m_out; } /* ------------------------------------------------------ */ inline Matrix FFT::itransform(Matrix& m_in) { int N = m_in.rows(); // Declare/initialize in and out for FFTW std::vector in(N*N); std::vector out(N*N); fftw_plan plan = fftw_plan_dft_2d( N, N, reinterpret_cast(&in[0]), reinterpret_cast(&out[0]), FFTW_BACKWARD, FFTW_ESTIMATE ); // Fill "in" with input-matrix values for (auto&& entry : index(m_in)) { // Get i and j index of matrix entry int i = std::get<0>(entry); int j = std::get<1>(entry); // Value of matrix entry at (i,j) complex val = std::get<2>(entry); // Create 1D index k from 2D indices (i,j) int k = i + j * N; // Insert value into in in[k] = val / (1.0*N*N); // Note: the divison by N sq follows from the fact that: // "FFTW computes unnormalized transforms: a transform followed by its // inverse will result in the original data multiplied by N (or the // product of the N’s for each dimension, in multi-dimensions)" // FFTW documentation } // Execute FFTW fftw_execute(plan); // Fill output-matrix with "out"-values Matrix m_out(N); for (auto&& entry : index(m_out)) { int i = std::get<0>(entry); int j = std::get<1>(entry); complex& v = std::get<2>(entry); int k = i + j * N; v = out[k]; } // Destroy FFTW fftw_destroy_plan(plan); return m_out; } /* ------------------------------------------------------ */ /* ------------------------------------------------------ */ inline Matrix> FFT::computeFrequencies(int size) { + // This function returns a matrix with entries equal to + // Freq^2 = (Freq_x)^2 + (Freq_y)^2 + // There is no need to take the square root as only the square of k will be used later + + // This vector contains np.fft.fftfreq(size) * size + std::vector Freq_x(size); + + if (size % 2){ // if N is odd + int half_size = (size-1)/2 + 1; + for (int i = 0; i < half_size; i++){ + Freq_x[i] = i; + } + int j = 0; + half_size--; + for (int i = -half_size; i < 0; i++){ + Freq_x[half_size + 1 + j] = i; + j++; + } + + } else{ // N is even + int half_size = size/2; + for (int i = 0; i < half_size; i++){ + Freq_x[i] = i; + } + int j = 0; + for (int i = -half_size; i < 0; i++){ + Freq_x[half_size + j] = i; + j++; + } + } + + // DEBUG: + // for (int i = 0; i < size; i++) + // std::cout << Freq_x[i] << " "; + // std::cout << std::endl; + + // Fill matrix + Matrix> out(size); + for (auto&& entry : index(out)) { + int i = std::get<0>(entry); + int j = std::get<1>(entry); + std::complex& v = std::get<2>(entry); + std::complex Freq_sq = Freq_x[i]*Freq_x[i] + Freq_x[j]*Freq_x[j]; + v = Freq_sq; + } + return out; } #endif // FFT_HH