comparison DEPENDENCIES/generic/include/boost/random/discrete_distribution.hpp @ 101:c530137014c0

Update Boost headers (1.58.0)
author Chris Cannam
date Mon, 07 Sep 2015 11:12:49 +0100
parents 2665513ce2d3
children
comparison
equal deleted inserted replaced
100:793467b5e61c 101:c530137014c0
5 * accompanying file LICENSE_1_0.txt or copy at 5 * accompanying file LICENSE_1_0.txt or copy at
6 * http://www.boost.org/LICENSE_1_0.txt) 6 * http://www.boost.org/LICENSE_1_0.txt)
7 * 7 *
8 * See http://www.boost.org for most recent version including documentation. 8 * See http://www.boost.org for most recent version including documentation.
9 * 9 *
10 * $Id: discrete_distribution.hpp 85813 2013-09-21 20:17:00Z jewillco $ 10 * $Id$
11 */ 11 */
12 12
13 #ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED 13 #ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
14 #define BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED 14 #define BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
15 15
18 #include <numeric> 18 #include <numeric>
19 #include <utility> 19 #include <utility>
20 #include <iterator> 20 #include <iterator>
21 #include <boost/assert.hpp> 21 #include <boost/assert.hpp>
22 #include <boost/random/uniform_01.hpp> 22 #include <boost/random/uniform_01.hpp>
23 #include <boost/random/uniform_int.hpp> 23 #include <boost/random/uniform_int_distribution.hpp>
24 #include <boost/random/detail/config.hpp> 24 #include <boost/random/detail/config.hpp>
25 #include <boost/random/detail/operators.hpp> 25 #include <boost/random/detail/operators.hpp>
26 #include <boost/random/detail/vector_io.hpp> 26 #include <boost/random/detail/vector_io.hpp>
27 27
28 #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST 28 #ifndef BOOST_NO_CXX11_HDR_INITIALIZER_LIST
34 34
35 #include <boost/random/detail/disable_warnings.hpp> 35 #include <boost/random/detail/disable_warnings.hpp>
36 36
37 namespace boost { 37 namespace boost {
38 namespace random { 38 namespace random {
39 namespace detail {
40
/**
 * Alias-table backend used when the weight type is integral.
 *
 * Weights are kept as exact integers rather than being scaled to
 * probabilities.  For a table of n bins the total weight is stored as
 * @c _average * n + @c _excess, so bins whose index is below @c _excess
 * carry a weight of @c _average + 1 and every other bin carries
 * @c _average (see get_weight()).
 */
template<class IntType, class WeightType>
struct integer_alias_table {
    /** Returns the total weight assigned to the column @c bin. */
    WeightType get_weight(IntType bin) const {
        WeightType result = _average;
        // The first _excess bins absorb the remainder of sum / n.
        if(bin < _excess) ++result;
        return result;
    }

    /**
     * Computes the quotient and remainder of sum(weights) / n for the
     * range [begin, end) without ever forming the full sum, stores them
     * in @c _average and @c _excess, and resizes the alias table to n
     * entries.
     *
     * @returns the quotient (the value stored in @c _average).
     */
    template<class Iter>
    WeightType init_average(Iter begin, Iter end) {
        WeightType weight_average = 0;
        IntType excess = 0;
        IntType n = 0;
        // Invariant: weight_average * n + excess == current partial sum
        // This is a bit messy, but it's guaranteed not to overflow
        for(Iter iter = begin; iter != end; ++iter) {
            ++n;
            if(*iter < weight_average) {
                // New element is below the running average: pull the
                // average down, borrowing from it if the remainder is
                // larger than the current excess.
                WeightType diff = weight_average - *iter;
                weight_average -= diff / n;
                if(diff % n > excess) {
                    --weight_average;
                    excess += n - diff % n;
                } else {
                    excess -= diff % n;
                }
            } else {
                // New element is at or above the running average: push
                // the average up, carrying into it if the excess would
                // reach n.
                WeightType diff = *iter - weight_average;
                weight_average += diff / n;
                if(diff % n < n - excess) {
                    excess += diff % n;
                } else {
                    ++weight_average;
                    excess -= n - diff % n;
                }
            }
        }
        _alias_table.resize(static_cast<std::size_t>(n));
        _average = weight_average;
        _excess = excess;
        return weight_average;
    }

    /** Resets the table to a single bin of weight 1 that aliases itself. */
    void init_empty()
    {
        _alias_table.clear();
        _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
                                              static_cast<IntType>(0)));
        _average = static_cast<WeightType>(1);
        _excess = static_cast<IntType>(0);
    }

    /** Two tables are equal when their entries, average and excess match. */
    bool operator==(const integer_alias_table& other) const
    {
        return _alias_table == other._alias_table &&
            _average == other._average && _excess == other._excess;
    }

    /** Integral weights are used as-is; the average is ignored. */
    static WeightType normalize(WeightType val, WeightType /*average*/)
    {
        return val;
    }

    /** Integral weights are never rescaled, so this is a no-op. */
    static void normalize(std::vector<WeightType>&) {}

    /** Draws the acceptance test value uniformly from [0, _average]. */
    template<class URNG>
    WeightType test(URNG &urng) const
    {
        return uniform_int_distribution<WeightType>(0, _average)(urng);
    }

    /**
     * Decides whether the pair (result, val) drawn by the caller is kept.
     * Bins below @c _excess have weight @c _average + 1 and accept every
     * value produced by test(); the remaining bins have weight
     * @c _average and therefore reject val == @c _average.
     */
    bool accept(IntType result, WeightType val) const
    {
        return result < _excess || val < _average;
    }

    /**
     * Returns the sum of @c weights, or 0 if that sum cannot be
     * represented in @c WeightType.  A sum that is genuinely zero
     * (for example an empty vector) also yields 0.
     *
     * NOTE(review): assumes the weights are non-negative - confirm
     * against the callers' preconditions.
     */
    static WeightType try_get_sum(const std::vector<WeightType>& weights)
    {
        const WeightType limit = (std::numeric_limits<WeightType>::max)();
        WeightType result = static_cast<WeightType>(0);
        for(typename std::vector<WeightType>::const_iterator
                iter = weights.begin(), end = weights.end();
            iter != end; ++iter)
        {
            // Check *before* adding: the next weight must fit in the
            // head-room that is left, otherwise report overflow.
            if(*iter > limit - result) {
                return static_cast<WeightType>(0);
            }
            result += *iter;
        }
        return result;
    }

    /** Draws a value uniformly from [0, max - 1]. */
    template<class URNG>
    static WeightType generate_in_range(URNG &urng, WeightType max)
    {
        return uniform_int_distribution<WeightType>(
            static_cast<WeightType>(0), max-1)(urng);
    }

    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
    alias_table_t _alias_table;
    WeightType _average;
    IntType _excess;
};
134
135 template<class IntType, class WeightType>
136 struct real_alias_table {
137 WeightType get_weight(IntType) const
138 {
139 return WeightType(1.0);
140 }
141 template<class Iter>
142 WeightType init_average(Iter first, Iter last)
143 {
144 std::size_t size = std::distance(first, last);
145 WeightType weight_sum =
146 std::accumulate(first, last, static_cast<WeightType>(0));
147 _alias_table.resize(size);
148 return weight_sum / size;
149 }
150 void init_empty()
151 {
152 _alias_table.clear();
153 _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
154 static_cast<IntType>(0)));
155 }
156 bool operator==(const real_alias_table& other) const
157 {
158 return _alias_table == other._alias_table;
159 }
160 static WeightType normalize(WeightType val, WeightType average)
161 {
162 return val / average;
163 }
164 static void normalize(std::vector<WeightType>& weights)
165 {
166 WeightType sum =
167 std::accumulate(weights.begin(), weights.end(),
168 static_cast<WeightType>(0));
169 for(typename std::vector<WeightType>::iterator
170 iter = weights.begin(),
171 end = weights.end();
172 iter != end; ++iter)
173 {
174 *iter /= sum;
175 }
176 }
177 template<class URNG>
178 WeightType test(URNG &urng) const
179 {
180 return uniform_01<WeightType>()(urng);
181 }
182 bool accept(IntType, WeightType) const
183 {
184 return true;
185 }
186 static WeightType try_get_sum(const std::vector<WeightType>& weights)
187 {
188 return static_cast<WeightType>(1);
189 }
190 template<class URNG>
191 static WeightType generate_in_range(URNG &urng, WeightType)
192 {
193 return uniform_01<WeightType>()(urng);
194 }
195 typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
196 alias_table_t _alias_table;
197 };
198
199 template<bool IsIntegral>
200 struct select_alias_table;
201
202 template<>
203 struct select_alias_table<true> {
204 template<class IntType, class WeightType>
205 struct apply {
206 typedef integer_alias_table<IntType, WeightType> type;
207 };
208 };
209
210 template<>
211 struct select_alias_table<false> {
212 template<class IntType, class WeightType>
213 struct apply {
214 typedef real_alias_table<IntType, WeightType> type;
215 };
216 };
217
218 }
39 219
40 /** 220 /**
41 * The class @c discrete_distribution models a \random_distribution. 221 * The class @c discrete_distribution models a \random_distribution.
42 * It produces integers in the range [0, n) with the probability 222 * It produces integers in the range [0, n) with the probability
43 * of producing each value is specified by the parameters of the 223 * of producing each value is specified by the parameters of the
153 explicit param_type(const discrete_distribution& dist) 333 explicit param_type(const discrete_distribution& dist)
154 : _probabilities(dist.probabilities()) 334 : _probabilities(dist.probabilities())
155 {} 335 {}
156 void normalize() 336 void normalize()
157 { 337 {
158 WeightType sum = 338 impl_type::normalize(_probabilities);
159 std::accumulate(_probabilities.begin(), _probabilities.end(),
160 static_cast<WeightType>(0));
161 for(typename std::vector<WeightType>::iterator
162 iter = _probabilities.begin(),
163 end = _probabilities.end();
164 iter != end; ++iter)
165 {
166 *iter /= sum;
167 }
168 } 339 }
169 std::vector<WeightType> _probabilities; 340 std::vector<WeightType> _probabilities;
170 /// @endcond 341 /// @endcond
171 }; 342 };
172 343
174 * Creates a new @c discrete_distribution object that has 345 * Creates a new @c discrete_distribution object that has
175 * \f$p(0) = 1\f$ and \f$p(i|i>0) = 0\f$. 346 * \f$p(0) = 1\f$ and \f$p(i|i>0) = 0\f$.
176 */ 347 */
177 discrete_distribution() 348 discrete_distribution()
178 { 349 {
179 _alias_table.push_back(std::make_pair(static_cast<WeightType>(1), 350 _impl.init_empty();
180 static_cast<IntType>(0)));
181 } 351 }
182 /** 352 /**
183 * Constructs a discrete_distribution from an iterator range. 353 * Constructs a discrete_distribution from an iterator range.
184 * If @c first == @c last, equivalent to the default constructor. 354 * If @c first == @c last, equivalent to the default constructor.
185 * Otherwise, the values of the range represent weights for the 355 * Otherwise, the values of the range represent weights for the
255 * discrete_distribution. 425 * discrete_distribution.
256 */ 426 */
257 template<class URNG> 427 template<class URNG>
258 IntType operator()(URNG& urng) const 428 IntType operator()(URNG& urng) const
259 { 429 {
260 BOOST_ASSERT(!_alias_table.empty()); 430 BOOST_ASSERT(!_impl._alias_table.empty());
261 WeightType test = uniform_01<WeightType>()(urng); 431 IntType result;
262 IntType result = uniform_int<IntType>((min)(), (max)())(urng); 432 WeightType test;
263 if(test < _alias_table[result].first) { 433 do {
434 result = uniform_int_distribution<IntType>((min)(), (max)())(urng);
435 test = _impl.test(urng);
436 } while(!_impl.accept(result, test));
437 if(test < _impl._alias_table[result].first) {
264 return result; 438 return result;
265 } else { 439 } else {
266 return(_alias_table[result].second); 440 return(_impl._alias_table[result].second);
267 } 441 }
268 } 442 }
269 443
270 /** 444 /**
271 * Returns a value distributed according to the parameters 445 * Returns a value distributed according to the parameters
272 * specified by param. 446 * specified by param.
273 */ 447 */
274 template<class URNG> 448 template<class URNG>
275 IntType operator()(URNG& urng, const param_type& parm) const 449 IntType operator()(URNG& urng, const param_type& parm) const
276 { 450 {
277 while(true) { 451 if(WeightType limit = impl_type::try_get_sum(parm._probabilities)) {
278 WeightType val = uniform_01<WeightType>()(urng); 452 WeightType val = impl_type::generate_in_range(urng, limit);
279 WeightType sum = 0; 453 WeightType sum = 0;
280 std::size_t result = 0; 454 std::size_t result = 0;
281 for(typename std::vector<WeightType>::const_iterator 455 for(typename std::vector<WeightType>::const_iterator
282 iter = parm._probabilities.begin(), 456 iter = parm._probabilities.begin(),
283 end = parm._probabilities.end(); 457 end = parm._probabilities.end();
284 iter != end; ++iter, ++result) 458 iter != end; ++iter, ++result)
285 { 459 {
286 sum += *iter; 460 sum += *iter;
287 if(sum > val) { 461 if(sum > val) {
288 return result; 462 return result;
289 } 463 }
290 } 464 }
465 // This shouldn't be reachable, but round-off error
466 // can prevent any match from being found when val is
467 // very close to 1.
468 return static_cast<IntType>(parm._probabilities.size() - 1);
469 } else {
470 // WeightType is integral and sum(parm._probabilities)
471 // would overflow. Just use the easy solution.
472 return discrete_distribution(parm)(urng);
291 } 473 }
292 } 474 }
293 475
294 /** Returns the smallest value that the distribution can produce. */ 476 /** Returns the smallest value that the distribution can produce. */
295 result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; } 477 result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; }
296 /** Returns the largest value that the distribution can produce. */ 478 /** Returns the largest value that the distribution can produce. */
297 result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const 479 result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const
298 { return static_cast<result_type>(_alias_table.size() - 1); } 480 { return static_cast<result_type>(_impl._alias_table.size() - 1); }
299 481
300 /** 482 /**
301 * Returns a vector containing the probabilities of each 483 * Returns a vector containing the probabilities of each
302 * value of the distribution. For example, given 484 * value of the distribution. For example, given
303 * 485 *
305 * discrete_distribution<> dist = { 1, 4, 5 }; 487 * discrete_distribution<> dist = { 1, 4, 5 };
306 * std::vector<double> p = dist.probabilities(); 488 * std::vector<double> p = dist.probabilities();
307 * @endcode 489 * @endcode
308 * 490 *
309 * the vector, p will contain {0.1, 0.4, 0.5}. 491 * the vector, p will contain {0.1, 0.4, 0.5}.
492 *
493 * If @c WeightType is integral, then the weights
494 * will be returned unchanged.
310 */ 495 */
311 std::vector<WeightType> probabilities() const 496 std::vector<WeightType> probabilities() const
312 { 497 {
313 std::vector<WeightType> result(_alias_table.size()); 498 std::vector<WeightType> result(_impl._alias_table.size());
314 const WeightType mean =
315 static_cast<WeightType>(1) / _alias_table.size();
316 std::size_t i = 0; 499 std::size_t i = 0;
317 for(typename alias_table_t::const_iterator 500 for(typename impl_type::alias_table_t::const_iterator
318 iter = _alias_table.begin(), 501 iter = _impl._alias_table.begin(),
319 end = _alias_table.end(); 502 end = _impl._alias_table.end();
320 iter != end; ++iter, ++i) 503 iter != end; ++iter, ++i)
321 { 504 {
322 WeightType val = iter->first * mean; 505 WeightType val = iter->first;
323 result[i] += val; 506 result[i] += val;
324 result[iter->second] += mean - val; 507 result[iter->second] += _impl.get_weight(i) - val;
325 } 508 }
509 impl_type::normalize(result);
326 return(result); 510 return(result);
327 } 511 }
328 512
329 /** Returns the parameters of the distribution. */ 513 /** Returns the parameters of the distribution. */
330 param_type param() const 514 param_type param() const
364 * Returns true if the two distributions will return the 548 * Returns true if the two distributions will return the
365 * same sequence of values, when passed equal generators. 549 * same sequence of values, when passed equal generators.
366 */ 550 */
367 BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs) 551 BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs)
368 { 552 {
369 return lhs._alias_table == rhs._alias_table; 553 return lhs._impl == rhs._impl;
370 } 554 }
371 /** 555 /**
372 * Returns true if the two distributions may return different 556 * Returns true if the two distributions may return different
373 * sequences of values, when passed equal generators. 557 * sequences of values, when passed equal generators.
374 */ 558 */
387 template<class Iter> 571 template<class Iter>
388 void init(Iter first, Iter last, std::forward_iterator_tag) 572 void init(Iter first, Iter last, std::forward_iterator_tag)
389 { 573 {
390 std::vector<std::pair<WeightType, IntType> > below_average; 574 std::vector<std::pair<WeightType, IntType> > below_average;
391 std::vector<std::pair<WeightType, IntType> > above_average; 575 std::vector<std::pair<WeightType, IntType> > above_average;
392 std::size_t size = std::distance(first, last); 576 WeightType weight_average = _impl.init_average(first, last);
393 WeightType weight_sum = 577 WeightType normalized_average = _impl.get_weight(0);
394 std::accumulate(first, last, static_cast<WeightType>(0));
395 WeightType weight_average = weight_sum / size;
396 std::size_t i = 0; 578 std::size_t i = 0;
397 for(; first != last; ++first, ++i) { 579 for(; first != last; ++first, ++i) {
398 WeightType val = *first / weight_average; 580 WeightType val = impl_type::normalize(*first, weight_average);
399 std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i)); 581 std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i));
400 if(val < static_cast<WeightType>(1)) { 582 if(val < normalized_average) {
401 below_average.push_back(elem); 583 below_average.push_back(elem);
402 } else { 584 } else {
403 above_average.push_back(elem); 585 above_average.push_back(elem);
404 } 586 }
405 } 587 }
406 588
407 _alias_table.resize(size); 589 typename impl_type::alias_table_t::iterator
408 typename alias_table_t::iterator
409 b_iter = below_average.begin(), 590 b_iter = below_average.begin(),
410 b_end = below_average.end(), 591 b_end = below_average.end(),
411 a_iter = above_average.begin(), 592 a_iter = above_average.begin(),
412 a_end = above_average.end() 593 a_end = above_average.end()
413 ; 594 ;
414 while(b_iter != b_end && a_iter != a_end) { 595 while(b_iter != b_end && a_iter != a_end) {
415 _alias_table[b_iter->second] = 596 _impl._alias_table[b_iter->second] =
416 std::make_pair(b_iter->first, a_iter->second); 597 std::make_pair(b_iter->first, a_iter->second);
417 a_iter->first -= (static_cast<WeightType>(1) - b_iter->first); 598 a_iter->first -= (_impl.get_weight(b_iter->second) - b_iter->first);
418 if(a_iter->first < static_cast<WeightType>(1)) { 599 if(a_iter->first < normalized_average) {
419 *b_iter = *a_iter++; 600 *b_iter = *a_iter++;
420 } else { 601 } else {
421 ++b_iter; 602 ++b_iter;
422 } 603 }
423 } 604 }
424 for(; b_iter != b_end; ++b_iter) { 605 for(; b_iter != b_end; ++b_iter) {
425 _alias_table[b_iter->second].first = static_cast<WeightType>(1); 606 _impl._alias_table[b_iter->second].first =
607 _impl.get_weight(b_iter->second);
426 } 608 }
427 for(; a_iter != a_end; ++a_iter) { 609 for(; a_iter != a_end; ++a_iter) {
428 _alias_table[a_iter->second].first = static_cast<WeightType>(1); 610 _impl._alias_table[a_iter->second].first =
611 _impl.get_weight(a_iter->second);
429 } 612 }
430 } 613 }
431 template<class Iter> 614 template<class Iter>
432 void init(Iter first, Iter last) 615 void init(Iter first, Iter last)
433 { 616 {
434 if(first == last) { 617 if(first == last) {
435 _alias_table.clear(); 618 _impl.init_empty();
436 _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
437 static_cast<IntType>(0)));
438 } else { 619 } else {
439 typename std::iterator_traits<Iter>::iterator_category category; 620 typename std::iterator_traits<Iter>::iterator_category category;
440 init(first, last, category); 621 init(first, last, category);
441 } 622 }
442 } 623 }
443 typedef std::vector<std::pair<WeightType, IntType> > alias_table_t; 624 typedef typename detail::select_alias_table<
444 alias_table_t _alias_table; 625 (::boost::is_integral<WeightType>::value)
626 >::template apply<IntType, WeightType>::type impl_type;
627 impl_type _impl;
445 /// @endcond 628 /// @endcond
446 }; 629 };
447 630
448 } 631 }
449 } 632 }