Path: blob/main/contrib/llvm-project/libcxx/include/__random/discrete_distribution.h
35233 views
//===----------------------------------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#ifndef _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H9#define _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H1011#include <__algorithm/upper_bound.h>12#include <__config>13#include <__random/is_valid.h>14#include <__random/uniform_real_distribution.h>15#include <cstddef>16#include <iosfwd>17#include <numeric>18#include <vector>1920#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)21# pragma GCC system_header22#endif2324_LIBCPP_PUSH_MACROS25#include <__undef_macros>2627_LIBCPP_BEGIN_NAMESPACE_STD2829template <class _IntType = int>30class _LIBCPP_TEMPLATE_VIS discrete_distribution {31static_assert(__libcpp_random_is_valid_inttype<_IntType>::value, "IntType must be a supported integer type");3233public:34// types35typedef _IntType result_type;3637class _LIBCPP_TEMPLATE_VIS param_type {38vector<double> __p_;3940public:41typedef discrete_distribution distribution_type;4243_LIBCPP_HIDE_FROM_ABI param_type() {}44template <class _InputIterator>45_LIBCPP_HIDE_FROM_ABI param_type(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {46__init();47}48#ifndef _LIBCPP_CXX03_LANG49_LIBCPP_HIDE_FROM_ABI param_type(initializer_list<double> __wl) : __p_(__wl.begin(), __wl.end()) { __init(); }50#endif // _LIBCPP_CXX03_LANG51template <class _UnaryOperation>52_LIBCPP_HIDE_FROM_ABI param_type(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw);5354_LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const;5556friend _LIBCPP_HIDE_FROM_ABI bool operator==(const param_type& __x, const param_type& __y) {57return __x.__p_ == __y.__p_;58}59friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const param_type& __x, const param_type& __y) { return !(__x == __y); }6061private:62_LIBCPP_HIDE_FROM_ABI void __init();6364friend class discrete_distribution;6566template <class _CharT, class _Traits, class _IT>67friend basic_ostream<_CharT, _Traits>&68operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x);6970template <class _CharT, class _Traits, class _IT>71friend basic_istream<_CharT, _Traits>&72operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x);73};7475private:76param_type __p_;7778public:79// constructor and reset functions80_LIBCPP_HIDE_FROM_ABI discrete_distribution() {}81template <class _InputIterator>82_LIBCPP_HIDE_FROM_ABI discrete_distribution(_InputIterator __f, _InputIterator __l) : __p_(__f, __l) {}83#ifndef _LIBCPP_CXX03_LANG84_LIBCPP_HIDE_FROM_ABI discrete_distribution(initializer_list<double> __wl) : __p_(__wl) {}85#endif // _LIBCPP_CXX03_LANG86template <class _UnaryOperation>87_LIBCPP_HIDE_FROM_ABI discrete_distribution(size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw)88: __p_(__nw, __xmin, __xmax, __fw) {}89_LIBCPP_HIDE_FROM_ABI explicit discrete_distribution(const param_type& __p) : __p_(__p) {}90_LIBCPP_HIDE_FROM_ABI void reset() {}9192// generating functions93template <class _URNG>94_LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g) {95return (*this)(__g, __p_);96}97template <class _URNG>98_LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p);99100// property functions101_LIBCPP_HIDE_FROM_ABI vector<double> probabilities() const { return __p_.probabilities(); }102103_LIBCPP_HIDE_FROM_ABI param_type param() const { return __p_; }104_LIBCPP_HIDE_FROM_ABI void param(const param_type& __p) { __p_ = __p; }105106_LIBCPP_HIDE_FROM_ABI result_type min() const { return 0; }107_LIBCPP_HIDE_FROM_ABI result_type max() const { return __p_.__p_.size(); }108109friend _LIBCPP_HIDE_FROM_ABI bool operator==(const discrete_distribution& __x, const discrete_distribution& __y) {110return __x.__p_ == __y.__p_;111}112friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const discrete_distribution& __x, const discrete_distribution& __y) {113return !(__x == __y);114}115116template <class _CharT, class _Traits, class _IT>117friend basic_ostream<_CharT, _Traits>&118operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x);119120template <class _CharT, class _Traits, class _IT>121friend basic_istream<_CharT, _Traits>&122operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x);123};124125template <class _IntType>126template <class _UnaryOperation>127discrete_distribution<_IntType>::param_type::param_type(128size_t __nw, double __xmin, double __xmax, _UnaryOperation __fw) {129if (__nw > 1) {130__p_.reserve(__nw - 1);131double __d = (__xmax - __xmin) / __nw;132double __d2 = __d / 2;133for (size_t __k = 0; __k < __nw; ++__k)134__p_.push_back(__fw(__xmin + __k * __d + __d2));135__init();136}137}138139template <class _IntType>140void discrete_distribution<_IntType>::param_type::__init() {141if (!__p_.empty()) {142if (__p_.size() > 1) {143double __s = std::accumulate(__p_.begin(), __p_.end(), 0.0);144for (vector<double>::iterator __i = __p_.begin(), __e = __p_.end(); __i < __e; ++__i)145*__i /= __s;146vector<double> __t(__p_.size() - 1);147std::partial_sum(__p_.begin(), __p_.end() - 1, __t.begin());148swap(__p_, __t);149} else {150__p_.clear();151__p_.shrink_to_fit();152}153}154}155156template <class _IntType>157vector<double> discrete_distribution<_IntType>::param_type::probabilities() const {158size_t __n = __p_.size();159vector<double> __p(__n + 1);160std::adjacent_difference(__p_.begin(), __p_.end(), __p.begin());161if (__n > 0)162__p[__n] = 1 - __p_[__n - 1];163else164__p[0] = 1;165return __p;166}167168template <class _IntType>169template <class _URNG>170_IntType discrete_distribution<_IntType>::operator()(_URNG& __g, const param_type& __p) {171static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "");172uniform_real_distribution<double> __gen;173return static_cast<_IntType>(std::upper_bound(__p.__p_.begin(), __p.__p_.end(), __gen(__g)) - __p.__p_.begin());174}175176template <class _CharT, class _Traits, class _IT>177_LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>&178operator<<(basic_ostream<_CharT, _Traits>& __os, const discrete_distribution<_IT>& __x) {179__save_flags<_CharT, _Traits> __lx(__os);180typedef basic_ostream<_CharT, _Traits> _OStream;181__os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific);182_CharT __sp = __os.widen(' ');183__os.fill(__sp);184size_t __n = __x.__p_.__p_.size();185__os << __n;186for (size_t __i = 0; __i < __n; ++__i)187__os << __sp << __x.__p_.__p_[__i];188return __os;189}190191template <class _CharT, class _Traits, class _IT>192_LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>&193operator>>(basic_istream<_CharT, _Traits>& __is, discrete_distribution<_IT>& __x) {194__save_flags<_CharT, _Traits> __lx(__is);195typedef basic_istream<_CharT, _Traits> _Istream;196__is.flags(_Istream::dec | _Istream::skipws);197size_t __n;198__is >> __n;199vector<double> __p(__n);200for (size_t __i = 0; __i < __n; ++__i)201__is >> __p[__i];202if (!__is.fail())203swap(__x.__p_.__p_, __p);204return __is;205}206207_LIBCPP_END_NAMESPACE_STD208209_LIBCPP_POP_MACROS210211#endif // _LIBCPP___RANDOM_DISCRETE_DISTRIBUTION_H212213214