34 #error You need to have activated FFTW3 (WITH_FFTW3) to include this file.
45 #include <boost/math/constants/constants.hpp>
47 #include "DGtalCatch.h"
48 #include "ConfigTest.h"
50 #include "DGtal/math/RealFFT.h"
51 #include "DGtal/images/ImageContainerBySTLVector.h"
52 #include "DGtal/io/readers/PGMReader.h"
53 #include "DGtal/io/writers/PGMWriter.h"
54 #include "DGtal/base/BasicFunctors.h"
56 using namespace DGtal;
69 using FFT = RealFFT< TDomain, TValue >;
71 INFO(
"Initializing RealFFT." );
72 FFT fft( anImage.
domain() );
74 INFO(
"Copying data from the image." );
75 auto spatial_image = fft.getSpatialImage();
76 std::copy( anImage.cbegin(), anImage.cend(), spatial_image.begin() );
78 INFO(
"Forward transformation." );
79 fft.forwardFFT( FFTW_ESTIMATE );
81 INFO(
"Checking modification of the input image." );
82 REQUIRE( ! std::equal( anImage.cbegin(), anImage.cend(), spatial_image.cbegin() ) );
84 INFO(
"Backward transformation." );
85 fft.backwardFFT( FFTW_ESTIMATE );
87 INFO(
"Comparing result with original image." );
88 const auto eps = 100 * std::numeric_limits<TValue>::epsilon() * std::log( anImage.
domain().size() );
91 for (
auto it = spatial_image.cbegin(); it != spatial_image.cend(); ++it )
93 if ( std::abs( *it - anImage( it.getPoint() ) ) > eps *
std::max( *it, TValue(1) ) )
94 FAIL(
"Approximation failed: " << *it <<
" - " << anImage( it.getPoint() ) <<
" = " << (*it - anImage( it.getPoint() ) ) );
104 template <
typename TIterator>
107 auto norm_max = std::norm(*it);
110 for ( ; it != it_end; ++it )
111 if ( std::norm(*it) > norm_max )
113 norm_max = std::norm(*it);
131 typedef RealFFT< TDomain, TValue > FFT;
134 const TValue pi = boost::math::constants::pi<TValue>();
135 const TValue freq = 5;
136 const TValue phase = pi/4;
138 INFO(
"Checking image size." );
139 REQUIRE( anImage.
extent()[ TDomain::dimension-1 ] >= 2*std::abs(freq) );
141 INFO(
"Initializing RealFFT." );
142 FFT fft( anImage.
domain(), RealPoint::zero, RealPoint::diagonal(1) );
144 INFO(
"Initializing spatial data." );
145 auto spatial_image = fft.getSpatialImage();
146 for (
auto it = spatial_image.begin(); it != spatial_image.end(); ++it )
148 const auto pt = fft.calcScaledSpatialCoords( it.getPoint() );
149 REQUIRE( fft.calcNativeSpatialCoords( pt ) == it.getPoint() );
151 *it = std::cos( 2*pi * freq * pt[ TDomain::dimension - 1 ] + phase );
154 INFO(
"Forward transformation." );
155 fft.forwardFFT( FFTW_ESTIMATE );
157 INFO(
"Finding maximal frequency..." );
158 const auto freq_image = fft.getFreqImage();
159 const auto it_max =
findMaxNorm( freq_image.cbegin(), freq_image.cend() );
160 const auto pt_max = it_max.getPoint();
161 INFO(
"\tat " << pt_max <<
" with value " << *it_max );
163 INFO(
"Checks maximal frequency on unit domain." );
165 auto freq_pt = fft.calcScaledFreqCoords( it_max.getPoint() );
166 auto freq_val = fft.calcScaledFreqValue( freq_pt, *it_max );
169 REQUIRE( fft.calcNativeFreqCoords( freq_pt, applyConj ) == it_max.getPoint() );
171 REQUIRE( std::norm( fft.calcNativeFreqValue( freq_pt, freq_val ) - *it_max ) == Approx(0.).
scale(std::norm(*it_max)) );
172 REQUIRE( std::norm( fft.getScaledFreqValue( freq_pt ) - freq_val ) == Approx(0.).
scale(std::norm(freq_val)) );
174 if ( freq_pt[ TDomain::dimension-1 ] * freq < 0 )
177 freq_val = std::conj( freq_val );
180 REQUIRE( ( freq_pt - RealPoint::base( TDomain::dimension-1, freq ) ).norm() == Approx( 0 ).
scale(freq_pt.norm()) );
181 REQUIRE( ( std::fmod( std::fmod( std::arg( freq_val ) - phase, 2*pi ) + 3*pi, 2*pi ) - pi ) == Approx( 0 ).
scale(pi) );
185 INFO(
"Checks maximal frequency on translated unit domain." );
187 fft.setScaledSpatialLowerBound( RealPoint::diagonal( 1. / (4*freq) ) );
189 auto freq_pt = fft.calcScaledFreqCoords( it_max.getPoint() );
190 auto freq_val = fft.calcScaledFreqValue( freq_pt, *it_max );
193 REQUIRE( fft.calcNativeFreqCoords( freq_pt, applyConj ) == it_max.getPoint() );
195 REQUIRE( std::norm( fft.calcNativeFreqValue( freq_pt, freq_val ) - *it_max ) == Approx(0.).
scale(std::norm(*it_max)) );
196 REQUIRE( std::norm( fft.getScaledFreqValue( freq_pt ) - freq_val ) == Approx(0.).
scale(std::norm(freq_val)) );
198 if ( freq_pt[ TDomain::dimension-1 ] * freq < 0 )
201 freq_val = std::conj( freq_val );
204 REQUIRE( ( freq_pt - RealPoint::base( TDomain::dimension-1, freq ) ).norm() == Approx( 0 ).
scale(freq_pt.norm()) );
205 REQUIRE( ( std::fmod( std::fmod( std::arg( freq_val ) - phase + pi/2, 2*pi) + 3*pi, 2*pi ) - pi ) == Approx( 0 ).
scale(pi) );
208 INFO(
"Checks maximal frequency on translated initial domain." );
210 const RealPoint shift = RealPoint::diagonal( 3. );
212 fft.setScaledSpatialExtent( anImage.
extent() );
213 fft.setScaledSpatialLowerBound( shift );
215 auto freq_pt = fft.calcScaledFreqCoords( it_max.getPoint() );
216 auto freq_val = fft.calcScaledFreqValue( freq_pt, *it_max );
219 REQUIRE( fft.calcNativeFreqCoords( freq_pt, applyConj ) == it_max.getPoint() );
221 REQUIRE( std::norm( fft.calcNativeFreqValue( freq_pt, freq_val ) - *it_max ) == Approx(0.).
scale(std::norm(*it_max)) );
222 REQUIRE( std::norm( fft.getScaledFreqValue( freq_pt ) - freq_val ) == Approx(0.).
scale(std::norm(freq_val)) );
224 if ( freq_pt[ TDomain::dimension-1 ] * freq < 0 )
227 freq_val = std::conj( freq_val );
230 const auto scaled_factor = freq/anImage.
extent()[ TDomain::dimension-1 ];
231 REQUIRE( ( freq_pt - RealPoint::base( TDomain::dimension-1, scaled_factor ) ).norm() == Approx( 0 ).
scale(freq_pt.norm()) );
232 REQUIRE( ( std::fmod( std::fmod( std::arg( freq_val ) - phase + 2*pi*scaled_factor*shift[TDomain::dimension-1], 2*pi ) + 3*pi, 2*pi ) - pi ) == Approx( 0 ).
scale(pi) );
247 using FFT = RealFFT< TDomain, TValue >;
254 INFO(
"Initializing RealFFT." );
257 FFT shifted_fft( shifted_domain );
259 INFO(
"Pre-creating plan." );
260 fft.createPlan( FFTW_MEASURE, FFTW_FORWARD );
262 INFO(
"Copying data from the image." );
263 auto spatial_image = fft.getSpatialImage();
264 std::copy( anImage.cbegin(), anImage.cend(), spatial_image.begin() );
266 auto shifted_spatial_image = shifted_fft.getSpatialImage();
267 const auto spatial_extent = shifted_fft.getSpatialExtent();
268 for (
auto it = shifted_spatial_image.begin(); it != shifted_spatial_image.end(); ++it )
271 Point pt = it.getPoint();
272 for (
typename Point::Dimension i = 0; i < Point::dimension; ++i )
274 pt[ i ] -= anImage.
extent()[ i ];
279 INFO(
"Forward transformation (forcing plan re-use)." );
280 fft.forwardFFT( FFTW_MEASURE | FFTW_WISDOM_ONLY );
281 shifted_fft.forwardFFT( FFTW_MEASURE | FFTW_WISDOM_ONLY );
283 INFO(
"Comparing results." );
284 auto freq_image = fft.getFreqImage();
285 auto shifted_freq_image = shifted_fft.getFreqImage();
286 const TValue eps = 100 * std::numeric_limits<TValue>::epsilon();
288 for (
auto it = freq_image.cbegin(), shifted_it = shifted_freq_image.cbegin(); it != freq_image.cend(); ++it, ++shifted_it )
291 fft.calcScaledFreqValue( fft.calcScaledFreqCoords( it.getPoint() ), *it )
292 - shifted_fft.calcScaledFreqValue( shifted_fft.calcScaledFreqCoords( shifted_it.getPoint() ), *shifted_it ) )
293 > eps *
std::max( std::norm(*it), TValue(1) ) )
294 FAIL(
"Approximation failed at point " << it.getPoint()
295 <<
" between " << *it
296 <<
" and " << shifted_fft.calcScaledFreqValue( shifted_fft.calcScaledFreqCoords( shifted_it.getPoint() ), *shifted_it )
297 <<
" (scaled from " << *shifted_it <<
")" );
304 #ifdef WITH_FFTW3_FLOAT
305 TEST_CASE(
"Checking RealFFT on a 2D image in float precision.",
"[2D][float]" )
308 const std::string file_name = testPath +
"/samples/church-small.pgm";
315 INFO(
"Importing image " );
318 INFO(
"Testing forward and backward FFT." );
321 INFO(
"Testing spatial domain scaling." );
324 INFO(
"Testing FFT on translated image." );
329 #ifdef WITH_FFTW3_DOUBLE
330 TEST_CASE(
"Checking RealFFT on a 2D image in double precision.",
"[2D][double]" )
333 const std::string file_name = testPath +
"/samples/church-small.pgm";
340 INFO(
"Importing image " );
343 INFO(
"Testing forward and backward FFT." );
346 INFO(
"Testing spatial domain scaling." );
349 INFO(
"Testing FFT on translated image." );
354 #ifdef WITH_FFTW3_LONG
355 TEST_CASE(
"Checking RealFFT on a 2D image in long double precision.",
"[2D][long double]" )
357 using namespace DGtal;
360 const std::string file_name = testPath +
"/samples/church-small.pgm";
362 using real =
long double;
367 INFO(
"Importing image " );
370 INFO(
"Testing forward and backward FFT." );
373 INFO(
"Testing spatial domain scaling." );
376 INFO(
"Testing FFT on translated image." );
381 #ifdef WITH_FFTW3_DOUBLE
382 TEST_CASE(
"Checking RealFFT on a 3D image in double precision.",
"[3D][double]" )
394 auto const extent = image.extent();
396 INFO(
"Filling the image randomly." );
397 const std::size_t CNT = image.size() / 100;
398 std::random_device rd;
399 std::mt19937 gen( rd() );
400 std::uniform_real_distribution<> dis{};
403 for ( std::size_t i = 0; i < CNT; ++i )
407 image.setValue( pt, 1. );
410 INFO(
"Testing forward and backward FFT." );
413 INFO(
"Testing spatial domain scaling." );
416 INFO(
"Testing FFT on translated image." );
421 #ifdef WITH_FFTW3_DOUBLE
422 TEST_CASE(
"Checking RealFFT on a 4D image in double precision.",
"[4D][double]" )
434 auto const extent = image.extent();
436 INFO(
"Filling the image randomly." );
437 const std::size_t CNT = image.size() / 100;
438 std::random_device rd;
439 std::mt19937 gen( rd() );
440 std::uniform_real_distribution<> dis{};
443 for ( std::size_t i = 0; i < CNT; ++i )
447 image.setValue( pt, 1. );
450 INFO(
"Testing forward and backward FFT." );
453 INFO(
"Testing spatial domain scaling." );
456 INFO(
"Testing FFT on translated image." );
const Point & lowerBound() const
const Point & upperBound() const
const Domain & domain() const
const Vector & extent() const
Aim: implements association bewteen points lying in a digital domain and values.
DGtal is the top-level namespace which contains all DGtal functions and types.
DGtal::uint32_t Dimension
Aim: Import a 2D or 3D using the Netpbm formats (ASCII mode).
void testForwardBackwardFFT(ImageContainerBySTLVector< TDomain, TValue > const &anImage)
TIterator findMaxNorm(TIterator it, TIterator const &it_end)
void cmpTranslatedFFT(ImageContainerBySTLVector< TDomain, TValue > const &anImage)
void testFFTScaling(ImageContainerBySTLVector< TDomain, TValue > const &anImage)
TEST_CASE("Checking RealFFT on a 2D image in float precision.", "[2D][float]")
REQUIRE(domain.isInside(aPoint))
PointVector< 3, double > RealPoint