Page MenuHomec4science

fft.hh
No OneTemporary

File Metadata

Created
Tue, May 14, 23:33
#ifndef FFT_HH
#define FFT_HH
/* ------------------------------------------------------ */
#include "matrix.hh"
#include "my_types.hh"
#include <fftw3.h>
#include <complex>
#include <cstddef>
#include <stdexcept>
#include <tuple>
#include <vector>
/* ------------------------------------------------------ */
// Collection of static helpers around FFTW's 2D complex DFT.
// All members are static; the struct only groups the functions.
struct FFT {
// Forward 2D DFT of m (planned with FFTW_FORWARD); returns a new matrix.
static Matrix<complex> transform(Matrix<complex>& m);
// Backward 2D DFT of m (planned with FFTW_BACKWARD); the input is scaled by
// 1/N^2 (N = m.rows()) before the transform. Returns a new matrix.
static Matrix<complex> itransform(Matrix<complex>& m);
// Returns a size x size matrix whose (i, j) entry is
// Freq_x[i]^2 + Freq_x[j]^2, with Freq_x the integer frequency layout
// built in the definition (non-negative frequencies first, then negative).
static Matrix<std::complex<int>> computeFrequencies(int size);
};
/* ------------------------------------------------------ */
/**
 * Forward 2D discrete Fourier transform of a complex matrix using FFTW.
 *
 * The matrix entries are copied into a flat buffer with the mapping
 * k = i + j * N, transformed with an N x N FFTW plan (FFTW_FORWARD,
 * FFTW_ESTIMATE) and copied back into a new matrix through the same mapping.
 * The result is not normalized.
 *
 * @param m_in input matrix. Only m_in.rows() is queried, so the matrix is
 *             assumed to be square (N x N).
 * @return a new matrix of size N holding the transformed values. For N == 0
 *         an empty matrix is returned without calling FFTW.
 * @throws std::runtime_error if FFTW fails to create the plan.
 */
inline Matrix<complex> FFT::transform(Matrix<complex>& m_in) {
  const int N = m_in.rows();

  // Allocate the result before any FFTW resource exists, so that a failing
  // allocation cannot leak the plan.
  Matrix<complex> m_out(N);

  // FFTW only plans strictly positive sizes; nothing to transform here.
  if (N == 0) {
    return m_out;
  }

  // Do the size arithmetic in size_t: N * N overflows int for large N.
  const std::size_t side = static_cast<std::size_t>(N);
  std::vector<complex> in(side * side);
  std::vector<complex> out(side * side);

  // With FFTW_ESTIMATE the planner does not touch the buffers, but the plan
  // is still created before "in" is filled, as FFTW recommends.
  fftw_plan plan = fftw_plan_dft_2d(N, N,
                                    reinterpret_cast<fftw_complex*>(in.data()),
                                    reinterpret_cast<fftw_complex*>(out.data()),
                                    FFTW_FORWARD, FFTW_ESTIMATE);
  // The planner returns NULL on failure; executing a NULL plan would crash.
  if (plan == nullptr) {
    throw std::runtime_error("FFT::transform: FFTW failed to create a plan");
  }

  // Fill "in" with the input-matrix values: (i, j) -> k = i + j * N.
  for (auto&& entry : index(m_in)) {
    const std::size_t i = static_cast<std::size_t>(std::get<0>(entry));
    const std::size_t j = static_cast<std::size_t>(std::get<1>(entry));
    const complex val = std::get<2>(entry);
    in[i + j * side] = val;
  }

  fftw_execute(plan);

  // Copy "out" back into the result matrix using the same index mapping.
  for (auto&& entry : index(m_out)) {
    const std::size_t i = static_cast<std::size_t>(std::get<0>(entry));
    const std::size_t j = static_cast<std::size_t>(std::get<1>(entry));
    complex& v = std::get<2>(entry);
    v = out[i + j * side];
  }

  fftw_destroy_plan(plan);
  return m_out;
}
/* ------------------------------------------------------ */
/**
 * Inverse (backward) 2D discrete Fourier transform of a complex matrix
 * using FFTW.
 *
 * The matrix entries are scaled by 1 / (N * N), copied into a flat buffer
 * with the mapping k = i + j * N, transformed with an N x N FFTW plan
 * (FFTW_BACKWARD, FFTW_ESTIMATE) and copied back into a new matrix through
 * the same mapping.
 *
 * The division by N squared is needed because, per the FFTW documentation,
 * FFTW computes unnormalized transforms: a transform followed by its inverse
 * yields the original data multiplied by N (or by the product of the N's of
 * each dimension in the multi-dimensional case).
 *
 * @param m_in input matrix. Only m_in.rows() is queried, so the matrix is
 *             assumed to be square (N x N).
 * @return a new matrix of size N holding the normalized inverse transform.
 *         For N == 0 an empty matrix is returned without calling FFTW.
 * @throws std::runtime_error if FFTW fails to create the plan.
 */
inline Matrix<complex> FFT::itransform(Matrix<complex>& m_in) {
  const int N = m_in.rows();

  // Allocate the result before any FFTW resource exists, so that a failing
  // allocation cannot leak the plan.
  Matrix<complex> m_out(N);

  // FFTW only plans strictly positive sizes; this also avoids dividing by
  // N * N == 0 below.
  if (N == 0) {
    return m_out;
  }

  // Do the size arithmetic in size_t: N * N overflows int for large N.
  const std::size_t side = static_cast<std::size_t>(N);
  std::vector<complex> in(side * side);
  std::vector<complex> out(side * side);

  // With FFTW_ESTIMATE the planner does not touch the buffers, but the plan
  // is still created before "in" is filled, as FFTW recommends.
  fftw_plan plan = fftw_plan_dft_2d(N, N,
                                    reinterpret_cast<fftw_complex*>(in.data()),
                                    reinterpret_cast<fftw_complex*>(out.data()),
                                    FFTW_BACKWARD, FFTW_ESTIMATE);
  // The planner returns NULL on failure; executing a NULL plan would crash.
  if (plan == nullptr) {
    throw std::runtime_error("FFT::itransform: FFTW failed to create a plan");
  }

  // Normalization factor, evaluated in double precision.
  const double norm = 1.0 * N * N;

  // Fill "in" with the scaled input values: (i, j) -> k = i + j * N.
  for (auto&& entry : index(m_in)) {
    const std::size_t i = static_cast<std::size_t>(std::get<0>(entry));
    const std::size_t j = static_cast<std::size_t>(std::get<1>(entry));
    const complex val = std::get<2>(entry);
    in[i + j * side] = val / norm;
  }

  fftw_execute(plan);

  // Copy "out" back into the result matrix using the same index mapping.
  for (auto&& entry : index(m_out)) {
    const std::size_t i = static_cast<std::size_t>(std::get<0>(entry));
    const std::size_t j = static_cast<std::size_t>(std::get<1>(entry));
    complex& v = std::get<2>(entry);
    v = out[i + j * side];
  }

  fftw_destroy_plan(plan);
  return m_out;
}
/* ------------------------------------------------------ */
/* ------------------------------------------------------ */
/**
 * Squared frequency magnitudes on a size x size grid.
 *
 * Entry (i, j) of the returned matrix is f[i]^2 + f[j]^2, where f is the
 * integer frequency layout 0, 1, 2, ... followed by the negative frequencies
 * ending at -1 (the original author describes it as
 * np.fft.fftfreq(size) * size). The square root is deliberately not taken:
 * only the squared magnitude is needed by the callers.
 *
 * @param size number of samples per dimension.
 * @return matrix of size `size` whose entries hold the squared magnitude in
 *         the real part (imaginary part is zero).
 */
inline Matrix<std::complex<int>> FFT::computeFrequencies(int size) {
  std::vector<int> freq(size);

  // Number of non-negative frequencies: ceil(size / 2). Slots below this
  // bound map to themselves; the rest wrap around to slot - size, which
  // covers the odd and the even case with a single formula.
  const int non_negative = size - size / 2;
  for (int slot = 0; slot < size; ++slot) {
    freq[slot] = (slot < non_negative) ? slot : slot - size;
  }

  // Combine the 1D layout into the 2D squared magnitude.
  Matrix<std::complex<int>> result(size);
  for (auto&& cell : index(result)) {
    const int row = std::get<0>(cell);
    const int col = std::get<1>(cell);
    std::complex<int>& target = std::get<2>(cell);
    target = std::complex<int>(freq[row] * freq[row] + freq[col] * freq[col]);
  }
  return result;
}
#endif // FFT_HH

Event Timeline