Path: blob/main/contrib/llvm-project/libcxx/include/__random/gamma_distribution.h
35233 views
//===----------------------------------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#ifndef _LIBCPP___RANDOM_GAMMA_DISTRIBUTION_H9#define _LIBCPP___RANDOM_GAMMA_DISTRIBUTION_H1011#include <__config>12#include <__random/exponential_distribution.h>13#include <__random/is_valid.h>14#include <__random/uniform_real_distribution.h>15#include <cmath>16#include <iosfwd>17#include <limits>1819#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)20# pragma GCC system_header21#endif2223_LIBCPP_PUSH_MACROS24#include <__undef_macros>2526_LIBCPP_BEGIN_NAMESPACE_STD2728template <class _RealType = double>29class _LIBCPP_TEMPLATE_VIS gamma_distribution {30static_assert(__libcpp_random_is_valid_realtype<_RealType>::value,31"RealType must be a supported floating-point type");3233public:34// types35typedef _RealType result_type;3637class _LIBCPP_TEMPLATE_VIS param_type {38result_type __alpha_;39result_type __beta_;4041public:42typedef gamma_distribution distribution_type;4344_LIBCPP_HIDE_FROM_ABI explicit param_type(result_type __alpha = 1, result_type __beta = 1)45: __alpha_(__alpha), __beta_(__beta) {}4647_LIBCPP_HIDE_FROM_ABI result_type alpha() const { return __alpha_; }48_LIBCPP_HIDE_FROM_ABI result_type beta() const { return __beta_; }4950friend _LIBCPP_HIDE_FROM_ABI bool operator==(const param_type& __x, const param_type& __y) {51return __x.__alpha_ == __y.__alpha_ && __x.__beta_ == __y.__beta_;52}53friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const param_type& __x, const param_type& __y) { return !(__x == __y); }54};5556private:57param_type __p_;5859public:60// constructors and reset functions61#ifndef _LIBCPP_CXX03_LANG62_LIBCPP_HIDE_FROM_ABI gamma_distribution() : gamma_distribution(1) {}63_LIBCPP_HIDE_FROM_ABI explicit gamma_distribution(result_type __alpha, result_type __beta = 1)64: __p_(param_type(__alpha, __beta)) {}65#else66_LIBCPP_HIDE_FROM_ABI explicit gamma_distribution(result_type __alpha = 1, result_type __beta = 1)67: __p_(param_type(__alpha, __beta)) {}68#endif69_LIBCPP_HIDE_FROM_ABI explicit gamma_distribution(const param_type& __p) : __p_(__p) {}70_LIBCPP_HIDE_FROM_ABI void reset() {}7172// generating functions73template <class _URNG>74_LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g) {75return (*this)(__g, __p_);76}77template <class _URNG>78_LIBCPP_HIDE_FROM_ABI result_type operator()(_URNG& __g, const param_type& __p);7980// property functions81_LIBCPP_HIDE_FROM_ABI result_type alpha() const { return __p_.alpha(); }82_LIBCPP_HIDE_FROM_ABI result_type beta() const { return __p_.beta(); }8384_LIBCPP_HIDE_FROM_ABI param_type param() const { return __p_; }85_LIBCPP_HIDE_FROM_ABI void param(const param_type& __p) { __p_ = __p; }8687_LIBCPP_HIDE_FROM_ABI result_type min() const { return 0; }88_LIBCPP_HIDE_FROM_ABI result_type max() const { return numeric_limits<result_type>::infinity(); }8990friend _LIBCPP_HIDE_FROM_ABI bool operator==(const gamma_distribution& __x, const gamma_distribution& __y) {91return __x.__p_ == __y.__p_;92}93friend _LIBCPP_HIDE_FROM_ABI bool operator!=(const gamma_distribution& __x, const gamma_distribution& __y) {94return !(__x == __y);95}96};9798template <class _RealType>99template <class _URNG>100_RealType gamma_distribution<_RealType>::operator()(_URNG& __g, const param_type& __p) {101static_assert(__libcpp_random_is_valid_urng<_URNG>::value, "");102result_type __a = __p.alpha();103uniform_real_distribution<result_type> __gen(0, 1);104exponential_distribution<result_type> __egen;105result_type __x;106if (__a == 1)107__x = __egen(__g);108else if (__a > 1) {109const result_type __b = __a - 1;110const result_type __c = 3 * __a - result_type(0.75);111while (true) {112const result_type __u = __gen(__g);113const result_type __v = __gen(__g);114const result_type __w = __u * (1 - __u);115if (__w != 0) {116const result_type __y = std::sqrt(__c / __w) * (__u - result_type(0.5));117__x = __b + __y;118if (__x >= 0) {119const result_type __z = 64 * __w * __w * __w * __v * __v;120if (__z <= 1 - 2 * __y * __y / __x)121break;122if (std::log(__z) <= 2 * (__b * std::log(__x / __b) - __y))123break;124}125}126}127} else // __a < 1128{129while (true) {130const result_type __u = __gen(__g);131const result_type __es = __egen(__g);132if (__u <= 1 - __a) {133__x = std::pow(__u, 1 / __a);134if (__x <= __es)135break;136} else {137const result_type __e = -std::log((1 - __u) / __a);138__x = std::pow(1 - __a + __a * __e, 1 / __a);139if (__x <= __e + __es)140break;141}142}143}144return __x * __p.beta();145}146147template <class _CharT, class _Traits, class _RT>148_LIBCPP_HIDE_FROM_ABI basic_ostream<_CharT, _Traits>&149operator<<(basic_ostream<_CharT, _Traits>& __os, const gamma_distribution<_RT>& __x) {150__save_flags<_CharT, _Traits> __lx(__os);151typedef basic_ostream<_CharT, _Traits> _OStream;152__os.flags(_OStream::dec | _OStream::left | _OStream::fixed | _OStream::scientific);153_CharT __sp = __os.widen(' ');154__os.fill(__sp);155__os << __x.alpha() << __sp << __x.beta();156return __os;157}158159template <class _CharT, class _Traits, class _RT>160_LIBCPP_HIDE_FROM_ABI basic_istream<_CharT, _Traits>&161operator>>(basic_istream<_CharT, _Traits>& __is, gamma_distribution<_RT>& __x) {162typedef gamma_distribution<_RT> _Eng;163typedef typename _Eng::result_type result_type;164typedef typename _Eng::param_type param_type;165__save_flags<_CharT, _Traits> __lx(__is);166typedef basic_istream<_CharT, _Traits> _Istream;167__is.flags(_Istream::dec | _Istream::skipws);168result_type __alpha;169result_type __beta;170__is >> __alpha >> __beta;171if (!__is.fail())172__x.param(param_type(__alpha, __beta));173return __is;174}175176_LIBCPP_END_NAMESPACE_STD177178_LIBCPP_POP_MACROS179180#endif // _LIBCPP___RANDOM_GAMMA_DISTRIBUTION_H181182183