1 // Copyright 2017 The Abseil Authors.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // https://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #ifndef ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
16 #define ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <istream>
#include <limits>
#include <ostream>
#include <type_traits>

#include "absl/random/internal/distribution_impl.h"
#include "absl/random/internal/fastmath.h"
#include "absl/random/internal/iostream_state_saver.h"
#include "absl/random/internal/traits.h"
#include "absl/random/uniform_int_distribution.h"
// log_uniform_int_distribution:
//
// Returns a random variate R in range [min, max] such that
// floor(log(R-min, base)) is uniformly distributed.
// We ensure uniformity by discretization using the
// boundary sets [0, 1, base, base * base, ... min(base^n, max)]:
// the offset R - min == 0 forms its own bucket, and bucket k (k >= 1)
// holds the offsets [base^(k-1), base^k - 1], truncated at max - min.
41 template <typename IntType = int>
42 class log_uniform_int_distribution {
45 typename random_internal::make_unsigned_bits<IntType>::type;
48 using result_type = IntType;
52 using distribution_type = log_uniform_int_distribution;
56 result_type max = (std::numeric_limits<result_type>::max)(),
61 range_(static_cast<unsigned_type>(max_) -
62 static_cast<unsigned_type>(min_)),
68 // Determine where the first set bit is on range(), giving a log2(range)
69 // value which can be used to construct bounds.
70 log_range_ = (std::min)(random_internal::LeadingSetBit(range()),
71 std::numeric_limits<unsigned_type>::digits);
73 // NOTE: Computing the logN(x) introduces error from 2 sources:
74 // 1. Conversion of int to double loses precision for values >=
75 // 2^53, which may cause some log() computations to operate on
77 // 2. The error introduced by the division will cause the result
78 // to differ from the expected value.
80 // Thus a result which should equal K may equal K +/- epsilon,
81 // which can eliminate some values depending on where the bounds fall.
82 const double inv_log_base = 1.0 / std::log(base_);
83 const double log_range = std::log(static_cast<double>(range()) + 0.5);
84 log_range_ = static_cast<int>(std::ceil(inv_log_base * log_range));
88 result_type(min)() const { return min_; }
89 result_type(max)() const { return max_; }
90 result_type base() const { return base_; }
92 friend bool operator==(const param_type& a, const param_type& b) {
93 return a.min_ == b.min_ && a.max_ == b.max_ && a.base_ == b.base_;
96 friend bool operator!=(const param_type& a, const param_type& b) {
101 friend class log_uniform_int_distribution;
103 int log_range() const { return log_range_; }
104 unsigned_type range() const { return range_; }
109 unsigned_type range_; // max - min
110 int log_range_; // ceil(logN(range_))
112 static_assert(std::is_integral<IntType>::value,
113 "Class-template absl::log_uniform_int_distribution<> must be "
114 "parameterized using an integral type.");
117 log_uniform_int_distribution() : log_uniform_int_distribution(0) {}
119 explicit log_uniform_int_distribution(
121 result_type max = (std::numeric_limits<result_type>::max)(),
122 result_type base = 2)
123 : param_(min, max, base) {}
125 explicit log_uniform_int_distribution(const param_type& p) : param_(p) {}
129 // generating functions
130 template <typename URBG>
131 result_type operator()(URBG& g) { // NOLINT(runtime/references)
132 return (*this)(g, param_);
135 template <typename URBG>
136 result_type operator()(URBG& g, // NOLINT(runtime/references)
137 const param_type& p) {
138 return (p.min)() + Generate(g, p);
141 result_type(min)() const { return (param_.min)(); }
142 result_type(max)() const { return (param_.max)(); }
143 result_type base() const { return param_.base(); }
145 param_type param() const { return param_; }
146 void param(const param_type& p) { param_ = p; }
148 friend bool operator==(const log_uniform_int_distribution& a,
149 const log_uniform_int_distribution& b) {
150 return a.param_ == b.param_;
152 friend bool operator!=(const log_uniform_int_distribution& a,
153 const log_uniform_int_distribution& b) {
154 return a.param_ != b.param_;
158 // Returns a log-uniform variate in the range [0, p.range()]. The caller
159 // should add min() to shift the result to the correct range.
160 template <typename URNG>
161 unsigned_type Generate(URNG& g, // NOLINT(runtime/references)
162 const param_type& p);
167 template <typename IntType>
168 template <typename URBG>
169 typename log_uniform_int_distribution<IntType>::unsigned_type
170 log_uniform_int_distribution<IntType>::Generate(
171 URBG& g, // NOLINT(runtime/references)
172 const param_type& p) {
173 // sample e over [0, log_range]. Map the results of e to this:
177 // n => [b^(n-1)..(b^n)-1]
178 const int e = absl::uniform_int_distribution<int>(0, p.log_range())(g);
184 unsigned_type base_e, top_e;
186 base_e = static_cast<unsigned_type>(1) << d;
188 top_e = (e >= std::numeric_limits<unsigned_type>::digits)
189 ? (std::numeric_limits<unsigned_type>::max)()
190 : (static_cast<unsigned_type>(1) << e) - 1;
192 const double r = std::pow(p.base(), d);
193 const double s = (r * p.base()) - 1.0;
195 base_e = (r > (std::numeric_limits<unsigned_type>::max)())
196 ? (std::numeric_limits<unsigned_type>::max)()
197 : static_cast<unsigned_type>(r);
199 top_e = (s > (std::numeric_limits<unsigned_type>::max)())
200 ? (std::numeric_limits<unsigned_type>::max)()
201 : static_cast<unsigned_type>(s);
204 const unsigned_type lo = (base_e >= p.range()) ? p.range() : base_e;
205 const unsigned_type hi = (top_e >= p.range()) ? p.range() : top_e;
207 // choose uniformly over [lo, hi]
208 return absl::uniform_int_distribution<result_type>(lo, hi)(g);
211 template <typename CharT, typename Traits, typename IntType>
212 std::basic_ostream<CharT, Traits>& operator<<(
213 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references)
214 const log_uniform_int_distribution<IntType>& x) {
216 typename random_internal::stream_format_type<IntType>::type;
217 auto saver = random_internal::make_ostream_state_saver(os);
218 os << static_cast<stream_type>((x.min)()) << os.fill()
219 << static_cast<stream_type>((x.max)()) << os.fill()
220 << static_cast<stream_type>(x.base());
224 template <typename CharT, typename Traits, typename IntType>
225 std::basic_istream<CharT, Traits>& operator>>(
226 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references)
227 log_uniform_int_distribution<IntType>& x) { // NOLINT(runtime/references)
228 using param_type = typename log_uniform_int_distribution<IntType>::param_type;
230 typename log_uniform_int_distribution<IntType>::result_type;
232 typename random_internal::stream_format_type<IntType>::type;
238 auto saver = random_internal::make_istream_state_saver(is);
239 is >> min >> max >> base;
241 x.param(param_type(static_cast<result_type>(min),
242 static_cast<result_type>(max),
243 static_cast<result_type>(base)));
250 #endif // ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_