// Fast normal random number generation // Copyright snsinfu 2018. // Distributed under the Boost Software License, Version 1.0. // // Permission is hereby granted, free of charge, to any person or organization // obtaining a copy of the software and accompanying documentation covered by // this license (the "Software") to use, reproduce, display, distribute, // execute, and transmit the Software, and to prepare derivative works of the // Software, and to permit third-parties to whom the Software is furnished to // do so, all subject to the following: // // The copyright notices in the Software and this entire statement, including // the above license grant, this restriction and the following disclaimer, // must be included in all copies of the Software, in whole or in part, and // all derivative works of the Software, unless such copies or derivative // works are solely in the form of machine-executable object code generated by // a source language processor. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. #ifndef INCLUDED_ZIGGURAT_HPP #define INCLUDED_ZIGGURAT_HPP #include <cmath> #include <cstddef> #include <cstdint> #include <ios> #include <istream> #include <limits> #include <ostream> #include <random> #if defined(__GNUC__) # define ZIGGURAT_LIKELY(x) __builtin_expect((x), 1) # define ZIGGURAT_NOINLINE __attribute__((noinline)) #else # define ZIGGURAT_LIKELY(x) (x) # define ZIGGURAT_NOINLINE #endif namespace cxx { namespace ziggurat_detail { // is_pow2m1 checks if num + 1 is a power of two. template<typename T> inline constexpr bool is_pow2m1(T num) { return (num & (num + 1)) == 0; } // log2 computes the base-2 logarithm of num truncated to integer. inline constexpr std::size_t log2(std::uint64_t num) { return num / 2 ? 1 + log2(num / 2) : 0; } // generate_bits draws N random bits from given random number generator. template<std::size_t N, typename URNG> inline std::uint64_t generate_bits(URNG& random) { constexpr std::uint64_t mask = (std::uint64_t(1) << N) - 1; if (URNG::min() == 0 && URNG::max() >= mask && is_pow2m1(URNG::max())) { return std::uint64_t(random()) & mask; } else { std::uniform_int_distribution<std::uint64_t> dist(0, mask); return dist(random); } } // canonicalize transforms N bits into a floating-point number in [0, 1). template<std::size_t N, typename T> inline T canonicalize(std::uint64_t bits) { constexpr int real_bits = std::numeric_limits<T>::digits; constexpr int uint_bits = N; constexpr int data_bits = (real_bits < uint_bits ? real_bits : uint_bits); constexpr T norm = 1 / T(std::int64_t(1) << data_bits); return norm * T(bits >> (uint_bits - data_bits)); } // gaussian returns exp(-x^2/2). template<typename T> inline T gaussian(T x) { return std::exp(T(-0.5) * x * x); } // normal_ziggurat holds a pre-computed ziggurat table. template<typename T> struct normal_ziggurat { static T const edges[0x81]; }; } // ziggurat_normal_distribution generates normal random numbers using the fast // ziggurat algorithm. template<typename T> class ziggurat_normal_distribution { // Pull in the ziggurat table to use. using ziggurat = ziggurat_detail::normal_ziggurat<T>; public: // result_type is an alias of T. using result_type = T; // param_type holds distribution parameters. struct param_type { using distribution_type = ziggurat_normal_distribution; // Default constructor initializes mean to 0 and stddev to 1. param_type() = default; // Single-parameter constructor initializes mean and stddev to given // values. explicit param_type(result_type mean, result_type stddev = 1) : mean_{mean}, stddev_{stddev} { } // mean returns the mean parameter. inline result_type mean() const { return mean_; } // stddev returns the stddev parameter. inline result_type stddev() const { return stddev_; } // Equality comparison p1 == p2 returns true if and only if mean and // stddev parameters, respectively, are the same for p1 and p2. friend bool operator==(param_type const& p1, param_type const& p2) { return p1.mean_ == p2.mean_ && p1.stddev_ == p2.stddev_; } friend bool operator!=(param_type const& p1, param_type const& p2) { return !(p1 == p2); } // Stream output write mean and stddev to a stream. template<typename Char, typename Tr> friend std::basic_ostream<Char, Tr>& operator<<( std::basic_ostream<Char, Tr>& os, param_type const& param ) { using sentry_type = typename std::basic_ostream<Char, Tr>::sentry; // TODO: need to normalize stream flags? if (sentry_type sentry{os}) { Char const space = os.widen(' '); os << param.mean_ << space << param.stddev_; } return os; } // Stream input reads mean and stddev from a stream. template<typename Char, typename Tr> friend std::basic_istream<Char, Tr>& operator>>( std::basic_istream<Char, Tr>& is, param_type& param ) { using sentry_type = typename std::basic_istream<Char, Tr>::sentry; // TODO: need to normalize stream flags? if (sentry_type sentry{is}) { param_type tmp; if (is >> tmp.mean_ >> tmp.stddev_) { param = tmp; } } return is; } private: result_type mean_ = 0; result_type stddev_ = 1; }; // Default constructor creates a normal distribution with mean = 0 and // stddev = 1. ziggurat_normal_distribution() = default; // This constructor creates a normal distribution with given mean and // stddev. explicit ziggurat_normal_distribution(result_type mean, result_type stddev = 1) : param_{mean, stddev} { } // This constructor creates a normal distribution having given // parameters. explicit ziggurat_normal_distribution(param_type const& param) : param_{param} { } // reset does nothing; this is a RandomNumberDistribution requirement. void reset() { } // Invoking a distribution with a random number engine returns a newly // generated normal random number with the preconfigured parameters. template<typename URNG> inline T operator()(URNG& random) { return param_.mean() + param_.stddev() * sample(random); } // Invoking a distribution with a random number engine and a parameter // object returns a newly generated normal random number with given // parameters. template<typename URNG> inline T operator()(URNG& random, param_type const& param) { return param.mean() + param.stddev() * sample(random); } // mean returns the mean parameter of this distribution. result_type mean() const { return param_.mean(); } // stddev returns the stddev parameter of this distribution. result_type stddev() const { return param_.stddev(); } // param returns the parameters of this distribution as a param_type. param_type param() const { return param_; } // param sets the parameters of this distribution. void param(param_type const& param) { param_ = param; } // min returns -infinity. result_type min() const { return -std::numeric_limits<result_type>::infinity(); } // max returns +infinity. result_type max() const { return std::numeric_limits<result_type>::infinity(); } private: // sample generates a standard normal number. template<typename URNG> inline T sample(URNG& random) const { constexpr std::size_t bit_count = ziggurat_detail::log2(URNG::max() - URNG::min()); for (;;) { auto const bits = ziggurat_detail::generate_bits<bit_count>(random); auto const uniform = ziggurat_detail::canonicalize<bit_count, T>(bits); auto const layer = std::size_t(bits & 0x7F); auto const sign = T((bits & 0x80) ? 1 : -1); auto const lower_edge = ziggurat::edges[layer]; auto const upper_edge = ziggurat::edges[layer + 1]; auto const x = uniform * lower_edge; if (ZIGGURAT_LIKELY(x < upper_edge)) { return sign * x; } if (layer == 0) { return sign * sample_from_tail(random); } if (check_accept(random, lower_edge, upper_edge, x)) { return sign * x; } } } template<typename URNG> ZIGGURAT_NOINLINE T sample_from_tail(URNG& random) const { T const tail_edge = ziggurat::edges[1]; std::uniform_real_distribution<T> uniform; T x, y; do { x = -std::log(uniform(random)) / tail_edge; y = -std::log(uniform(random)); } while (2 * y < x * x); return tail_edge + x; } template<typename URNG> ZIGGURAT_NOINLINE bool check_accept(URNG& random, T lower_edge, T upper_edge, T x) const { // Rejection sampling from the interval [upper_edge, lower_edge]. std::uniform_real_distribution<T> uniform( ziggurat_detail::gaussian(lower_edge), ziggurat_detail::gaussian(upper_edge) ); return uniform(random) < ziggurat_detail::gaussian(x); } private: param_type param_; }; // Equality comparison d1 == d2 compares the equality of distribution // parameters. template<typename T> bool operator==( ziggurat_normal_distribution<T> const& d1, ziggurat_normal_distribution<T> const& d2 ) { return d1.param() == d2.param(); } template<typename T> bool operator!=( ziggurat_normal_distribution<T> const& d1, ziggurat_normal_distribution<T> const& d2 ) { return !(d1 == d2); } // Stream output operator writes mean and stddev parameters to a stream. template<typename Char, typename Tr, typename T> std::basic_ostream<Char, Tr>& operator<<( std::basic_ostream<Char, Tr>& os, ziggurat_normal_distribution<T> const& dist ) { return os << dist.param(); } // Stream input operator reads mean and stddev parameters from a stream. template<typename Char, typename Tr, typename T> std::basic_istream<Char, Tr>& operator>>( std::basic_istream<Char, Tr>& is, ziggurat_normal_distribution<T>& dist ) { typename ziggurat_normal_distribution<T>::param_type param; if (is >> param) { dist.param(param); } return is; } // Pre-computed ziggurat table. template<typename T> T const ziggurat_detail::normal_ziggurat<T>::edges[] = { T(3.71308624674036292), T(3.44261985589665231), T(3.22308498457861869), T(3.083228858214214), T(2.97869625264501714), T(2.89434400701867078), T(2.82312535054596658), T(2.76116937238415394), T(2.70611357311872291), T(2.65640641125819288), T(2.61097224842861353), T(2.56903362592163953), T(2.53000967238546703), T(2.49345452209195129), T(2.4590181774083506), T(2.42642064553021219), T(2.395434278007468), T(2.36587137011398818), T(2.3375752413355313), T(2.31041368369500244), T(2.28427405967365704), T(2.25905957386533007), T(2.23468639558705728), T(2.21108140887472837), T(2.18818043207202084), T(2.16592679374484121), T(2.14427018235626177), T(2.1231657086697906), T(2.10257313518499966), T(2.08245623798772517), T(2.0627822745039639), T(2.04352153665067027), T(2.02464697337293442), T(2.00613386995896725), T(1.98795957412306135), T(1.97010326084971399), T(1.9525457295488895), T(1.93526922829190084), T(1.91825730085973256), T(1.90149465310031829), T(1.88496703570286983), T(1.86866114098954261), T(1.85256451172308778), T(1.83666546025338473), T(1.82095299659100562), T(1.80541676421404929), T(1.79004698259461947), T(1.77483439558076972), T(1.7597702248942324), T(1.74484612810837714), T(1.73005416055824424), T(1.71538674070811714), T(1.70083661856430157), T(1.68639684677348689), T(1.67206075409185284), T(1.65782192094820813), T(1.64367415685698326), T(1.62961147946467899), T(1.61562809503713356), T(1.6017183802152779), T(1.5878768648844015), T(1.57409821601675048), T(1.56037722235984133), T(1.54670877985350419), T(1.53308787766755672), T(1.51950958475937159), T(1.50596903685655104), T(1.49246142377461632), T(1.47898197698309875), T(1.46552595733579549), T(1.45208864288221728), T(1.43866531667746189), T(1.4252512545068623), T(1.41184171243976109), T(1.39843191412360701), T(1.38501703772514939), T(1.3715922024197329), T(1.35815245432242371), T(1.3446927517457139), T(1.33120794965767741), T(1.31769278320134386), T(1.30414185012042227), T(1.29054959191787399), T(1.27691027355170061), T(1.26321796144602927), T(1.24946649956433475), T(1.23564948325448198), T(1.22176023053096339), T(1.20779175040675857), T(1.19373670782377306), T(1.17958738465446178), T(1.16533563615504776), T(1.1509728421389771), T(1.1364898520030764), T(1.12187692257225491), T(1.1071236475235362), T(1.09221887689655461), T(1.07715062488193869), T(1.06190596368362034), T(1.0464709007525812), T(1.03083023605645652), T(1.01496739523930057), T(0.998864233480644681), T(0.982500803502761477), T(0.965855079388131865), T(0.948902625497913155), T(0.931616196601354973), T(0.913965251008802881), T(0.895915352566239664), T(0.877427429097716982), T(0.858456843178052043), T(0.838952214281208697), T(0.818853906683319033), T(0.798092060626276134), T(0.776583987876149906), T(0.754230664434511699), T(0.730911910621882877), T(0.706479611313609812), T(0.680747918645906114), T(0.653478638715044413), T(0.62435859730909038), T(0.592962942441980445), T(0.558692178375520654), T(0.520656038725148096), T(0.477437837253791464), T(0.426547986303309479), T(0.362871431028424229), T(0.272320864704672982), T(8.56006539842194211e-08) }; } #undef ZIGGURAT_LIKELY #undef ZIGGURAT_NOINLINE #endif