Path: blob/master/thirdparty/basis_universal/encoder/basisu_math.h
9903 views
// File: basisu_math.h1#pragma once23// TODO: Would prefer this in the basisu namespace, but to avoid collisions with the existing vec/matrix classes I'm placing this in "bu_math".4namespace bu_math5{6// Cross-platform 1.0f/sqrtf(x) approximation. See https://en.wikipedia.org/wiki/Fast_inverse_square_root#cite_note-37.7// Would prefer using SSE1 etc. but that would require implementing multiple versions and platform divergence (needing more testing).8BASISU_FORCE_INLINE float inv_sqrt(float v)9{10union11{12float flt;13uint32_t ui;14} un;1516un.flt = v;17un.ui = 0x5F1FFFF9UL - (un.ui >> 1);1819return 0.703952253f * un.flt * (2.38924456f - v * (un.flt * un.flt));20}2122inline float smoothstep(float edge0, float edge1, float x)23{24assert(edge1 != edge0);2526// Scale, and clamp x to 0..1 range27x = basisu::saturate((x - edge0) / (edge1 - edge0));2829return x * x * (3.0f - 2.0f * x);30}3132template <uint32_t N, typename T>33class vec : public basisu::rel_ops<vec<N, T> >34{35public:36typedef T scalar_type;37enum38{39num_elements = N40};4142inline vec()43{44}4546inline vec(basisu::eClear)47{48clear();49}5051inline vec(const vec& other)52{53for (uint32_t i = 0; i < N; i++)54m_s[i] = other.m_s[i];55}5657template <uint32_t O, typename U>58inline vec(const vec<O, U>& other)59{60set(other);61}6263template <uint32_t O, typename U>64inline vec(const vec<O, U>& other, T w)65{66*this = other;67m_s[N - 1] = w;68}6970template <typename... Args>71inline explicit vec(Args... args)72{73static_assert(sizeof...(args) <= N);74set(args...);75}7677inline void clear()78{79if (N > 4)80memset(m_s, 0, sizeof(m_s));81else82{83for (uint32_t i = 0; i < N; i++)84m_s[i] = 0;85}86}8788template <uint32_t ON, typename OT>89inline vec& set(const vec<ON, OT>& other)90{91if ((void*)this == (void*)&other)92return *this;93const uint32_t m = basisu::minimum(N, ON);94uint32_t i;95for (i = 0; i < m; i++)96m_s[i] = static_cast<T>(other[i]);97for (; i < N; i++)98m_s[i] = 0;99return *this;100}101102inline vec& set_component(uint32_t index, T val)103{104assert(index < N);105m_s[index] = val;106return *this;107}108109inline vec& set_all(T val)110{111for (uint32_t i = 0; i < N; i++)112m_s[i] = val;113return *this;114}115116template <typename... Args>117inline vec& set(Args... args)118{119static_assert(sizeof...(args) <= N);120121// Initialize using parameter pack expansion122T values[] = { static_cast<T>(args)... };123124// Special case if setting with a scalar125if (sizeof...(args) == 1)126{127set_all(values[0]);128}129else130{131// Copy the values into the vector132for (std::size_t i = 0; i < sizeof...(args); ++i)133{134m_s[i] = values[i];135}136137// Zero-initialize the remaining elements (if any)138if (sizeof...(args) < N)139{140std::fill(m_s + sizeof...(args), m_s + N, T{});141}142}143144return *this;145}146147inline vec& set(const T* pValues)148{149for (uint32_t i = 0; i < N; i++)150m_s[i] = pValues[i];151return *this;152}153154template <uint32_t ON, typename OT>155inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i)156{157return set(static_cast<T>(other[i]));158}159160template <uint32_t ON, typename OT>161inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i, uint32_t j)162{163return set(static_cast<T>(other[i]), static_cast<T>(other[j]));164}165166template <uint32_t ON, typename OT>167inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i, uint32_t j, uint32_t k)168{169return set(static_cast<T>(other[i]), static_cast<T>(other[j]), static_cast<T>(other[k]));170}171172template <uint32_t ON, typename OT>173inline vec& swizzle_set(const vec<ON, OT>& other, uint32_t i, uint32_t j, uint32_t k, uint32_t l)174{175return set(static_cast<T>(other[i]), static_cast<T>(other[j]), static_cast<T>(other[k]), static_cast<T>(other[l]));176}177178inline vec& operator=(const vec& rhs)179{180if (this != &rhs)181{182for (uint32_t i = 0; i < N; i++)183m_s[i] = rhs.m_s[i];184}185return *this;186}187188template <uint32_t O, typename U>189inline vec& operator=(const vec<O, U>& other)190{191if ((void*)this == (void*)&other)192return *this;193194uint32_t s = basisu::minimum(N, O);195196uint32_t i;197for (i = 0; i < s; i++)198m_s[i] = static_cast<T>(other[i]);199200for (; i < N; i++)201m_s[i] = 0;202203return *this;204}205206inline bool operator==(const vec& rhs) const207{208for (uint32_t i = 0; i < N; i++)209if (!(m_s[i] == rhs.m_s[i]))210return false;211return true;212}213214inline bool operator<(const vec& rhs) const215{216for (uint32_t i = 0; i < N; i++)217{218if (m_s[i] < rhs.m_s[i])219return true;220else if (!(m_s[i] == rhs.m_s[i]))221return false;222}223224return false;225}226227inline T operator[](uint32_t i) const228{229assert(i < N);230return m_s[i];231}232233inline T& operator[](uint32_t i)234{235assert(i < N);236return m_s[i];237}238239template <uint32_t index>240inline uint64_t get_component_bits_as_uint() const241{242static_assert(index < N);243static_assert((sizeof(T) == sizeof(uint16_t)) || (sizeof(T) == sizeof(uint32_t)) || (sizeof(T) == sizeof(uint64_t)), "Unsupported type");244245if (sizeof(T) == sizeof(uint16_t))246return *reinterpret_cast<const uint16_t*>(&m_s[index]);247else if (sizeof(T) == sizeof(uint32_t))248return *reinterpret_cast<const uint32_t*>(&m_s[index]);249else if (sizeof(T) == sizeof(uint64_t))250return *reinterpret_cast<const uint64_t*>(&m_s[index]);251else252{253assert(0);254return 0;255}256}257258inline T get_x(void) const259{260return m_s[0];261}262inline T get_y(void) const263{264static_assert(N >= 2);265return m_s[1];266}267inline T get_z(void) const268{269static_assert(N >= 3);270return m_s[2];271}272inline T get_w(void) const273{274static_assert(N >= 4);275return m_s[3];276}277278inline vec get_x_vector() const279{280return broadcast<0>();281}282inline vec get_y_vector() const283{284return broadcast<1>();285}286inline vec get_z_vector() const287{288return broadcast<2>();289}290inline vec get_w_vector() const291{292return broadcast<3>();293}294295inline T get_component(uint32_t i) const296{297return (*this)[i];298}299300inline vec& set_x(T v)301{302m_s[0] = v;303return *this;304}305inline vec& set_y(T v)306{307static_assert(N >= 2);308m_s[1] = v;309return *this;310}311inline vec& set_z(T v)312{313static_assert(N >= 3);314m_s[2] = v;315return *this;316}317inline vec& set_w(T v)318{319static_assert(N >= 4);320m_s[3] = v;321return *this;322}323324inline const T* get_ptr() const325{326return reinterpret_cast<const T*>(&m_s[0]);327}328inline T* get_ptr()329{330return reinterpret_cast<T*>(&m_s[0]);331}332333inline vec as_point() const334{335vec result(*this);336result[N - 1] = 1;337return result;338}339340inline vec as_dir() const341{342vec result(*this);343result[N - 1] = 0;344return result;345}346347inline vec<2, T> select2(uint32_t i, uint32_t j) const348{349assert((i < N) && (j < N));350return vec<2, T>(m_s[i], m_s[j]);351}352353inline vec<3, T> select3(uint32_t i, uint32_t j, uint32_t k) const354{355assert((i < N) && (j < N) && (k < N));356return vec<3, T>(m_s[i], m_s[j], m_s[k]);357}358359inline vec<4, T> select4(uint32_t i, uint32_t j, uint32_t k, uint32_t l) const360{361assert((i < N) && (j < N) && (k < N) && (l < N));362return vec<4, T>(m_s[i], m_s[j], m_s[k], m_s[l]);363}364365inline bool is_dir() const366{367return m_s[N - 1] == 0;368}369inline bool is_vector() const370{371return is_dir();372}373inline bool is_point() const374{375return m_s[N - 1] == 1;376}377378inline vec project() const379{380vec result(*this);381if (result[N - 1])382result /= result[N - 1];383return result;384}385386inline vec broadcast(unsigned i) const387{388return vec((*this)[i]);389}390391template <uint32_t i>392inline vec broadcast() const393{394return vec((*this)[i]);395}396397inline vec swizzle(uint32_t i, uint32_t j) const398{399return vec((*this)[i], (*this)[j]);400}401402inline vec swizzle(uint32_t i, uint32_t j, uint32_t k) const403{404return vec((*this)[i], (*this)[j], (*this)[k]);405}406407inline vec swizzle(uint32_t i, uint32_t j, uint32_t k, uint32_t l) const408{409return vec((*this)[i], (*this)[j], (*this)[k], (*this)[l]);410}411412inline vec operator-() const413{414vec result;415for (uint32_t i = 0; i < N; i++)416result.m_s[i] = -m_s[i];417return result;418}419420inline vec operator+() const421{422return *this;423}424425inline vec& operator+=(const vec& other)426{427for (uint32_t i = 0; i < N; i++)428m_s[i] += other.m_s[i];429return *this;430}431432inline vec& operator-=(const vec& other)433{434for (uint32_t i = 0; i < N; i++)435m_s[i] -= other.m_s[i];436return *this;437}438439inline vec& operator*=(const vec& other)440{441for (uint32_t i = 0; i < N; i++)442m_s[i] *= other.m_s[i];443return *this;444}445446inline vec& operator/=(const vec& other)447{448for (uint32_t i = 0; i < N; i++)449m_s[i] /= other.m_s[i];450return *this;451}452453inline vec& operator*=(T s)454{455for (uint32_t i = 0; i < N; i++)456m_s[i] *= s;457return *this;458}459460inline vec& operator/=(T s)461{462for (uint32_t i = 0; i < N; i++)463m_s[i] /= s;464return *this;465}466467friend inline vec operator*(const vec& lhs, T val)468{469vec result;470for (uint32_t i = 0; i < N; i++)471result.m_s[i] = lhs.m_s[i] * val;472return result;473}474475friend inline vec operator*(T val, const vec& rhs)476{477vec result;478for (uint32_t i = 0; i < N; i++)479result.m_s[i] = val * rhs.m_s[i];480return result;481}482483friend inline vec operator/(const vec& lhs, const vec& rhs)484{485vec result;486for (uint32_t i = 0; i < N; i++)487result.m_s[i] = lhs.m_s[i] / rhs.m_s[i];488return result;489}490491friend inline vec operator/(const vec& lhs, T val)492{493vec result;494for (uint32_t i = 0; i < N; i++)495result.m_s[i] = lhs.m_s[i] / val;496return result;497}498499friend inline vec operator+(const vec& lhs, const vec& rhs)500{501vec result;502for (uint32_t i = 0; i < N; i++)503result.m_s[i] = lhs.m_s[i] + rhs.m_s[i];504return result;505}506507friend inline vec operator-(const vec& lhs, const vec& rhs)508{509vec result;510for (uint32_t i = 0; i < N; i++)511result.m_s[i] = lhs.m_s[i] - rhs.m_s[i];512return result;513}514515static inline vec<3, T> cross2(const vec& a, const vec& b)516{517static_assert(N >= 2);518return vec<3, T>(0, 0, a[0] * b[1] - a[1] * b[0]);519}520521inline vec<3, T> cross2(const vec& b) const522{523return cross2(*this, b);524}525526static inline vec<3, T> cross3(const vec& a, const vec& b)527{528static_assert(N >= 3);529return vec<3, T>(a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]);530}531532inline vec<3, T> cross3(const vec& b) const533{534return cross3(*this, b);535}536537static inline vec<3, T> cross(const vec& a, const vec& b)538{539static_assert(N >= 2);540541if (N == 2)542return cross2(a, b);543else544return cross3(a, b);545}546547inline vec<3, T> cross(const vec& b) const548{549static_assert(N >= 2);550return cross(*this, b);551}552553inline T dot(const vec& rhs) const554{555return dot(*this, rhs);556}557558inline vec dot_vector(const vec& rhs) const559{560return vec(dot(*this, rhs));561}562563static inline T dot(const vec& lhs, const vec& rhs)564{565T result = lhs.m_s[0] * rhs.m_s[0];566for (uint32_t i = 1; i < N; i++)567result += lhs.m_s[i] * rhs.m_s[i];568return result;569}570571inline T dot2(const vec& rhs) const572{573static_assert(N >= 2);574return m_s[0] * rhs.m_s[0] + m_s[1] * rhs.m_s[1];575}576577inline T dot3(const vec& rhs) const578{579static_assert(N >= 3);580return m_s[0] * rhs.m_s[0] + m_s[1] * rhs.m_s[1] + m_s[2] * rhs.m_s[2];581}582583inline T dot4(const vec& rhs) const584{585static_assert(N >= 4);586return m_s[0] * rhs.m_s[0] + m_s[1] * rhs.m_s[1] + m_s[2] * rhs.m_s[2] + m_s[3] * rhs.m_s[3];587}588589inline T norm(void) const590{591T sum = m_s[0] * m_s[0];592for (uint32_t i = 1; i < N; i++)593sum += m_s[i] * m_s[i];594return sum;595}596597inline T length(void) const598{599return sqrt(norm());600}601602inline T squared_distance(const vec& rhs) const603{604T dist2 = 0;605for (uint32_t i = 0; i < N; i++)606{607T d = m_s[i] - rhs.m_s[i];608dist2 += d * d;609}610return dist2;611}612613inline T squared_distance(const vec& rhs, T early_out) const614{615T dist2 = 0;616for (uint32_t i = 0; i < N; i++)617{618T d = m_s[i] - rhs.m_s[i];619dist2 += d * d;620if (dist2 > early_out)621break;622}623return dist2;624}625626inline T distance(const vec& rhs) const627{628T dist2 = 0;629for (uint32_t i = 0; i < N; i++)630{631T d = m_s[i] - rhs.m_s[i];632dist2 += d * d;633}634return sqrt(dist2);635}636637inline vec inverse() const638{639vec result;640for (uint32_t i = 0; i < N; i++)641result[i] = m_s[i] ? (1.0f / m_s[i]) : 0;642return result;643}644645// returns squared length (norm)646inline double normalize(const vec* pDefaultVec = NULL)647{648double n = m_s[0] * m_s[0];649for (uint32_t i = 1; i < N; i++)650n += m_s[i] * m_s[i];651652if (n != 0)653*this *= static_cast<T>(1.0f / sqrt(n));654else if (pDefaultVec)655*this = *pDefaultVec;656return n;657}658659inline double normalize3(const vec* pDefaultVec = NULL)660{661static_assert(N >= 3);662663double n = m_s[0] * m_s[0] + m_s[1] * m_s[1] + m_s[2] * m_s[2];664665if (n != 0)666*this *= static_cast<T>((1.0f / sqrt(n)));667else if (pDefaultVec)668*this = *pDefaultVec;669return n;670}671672inline vec& normalize_in_place(const vec* pDefaultVec = NULL)673{674normalize(pDefaultVec);675return *this;676}677678inline vec& normalize3_in_place(const vec* pDefaultVec = NULL)679{680normalize3(pDefaultVec);681return *this;682}683684inline vec get_normalized(const vec* pDefaultVec = NULL) const685{686vec result(*this);687result.normalize(pDefaultVec);688return result;689}690691inline vec get_normalized3(const vec* pDefaultVec = NULL) const692{693vec result(*this);694result.normalize3(pDefaultVec);695return result;696}697698inline vec& clamp(T l, T h)699{700for (uint32_t i = 0; i < N; i++)701m_s[i] = static_cast<T>(basisu::clamp(m_s[i], l, h));702return *this;703}704705inline vec& saturate()706{707return clamp(0.0f, 1.0f);708}709710inline vec& clamp(const vec& l, const vec& h)711{712for (uint32_t i = 0; i < N; i++)713m_s[i] = static_cast<T>(basisu::clamp(m_s[i], l[i], h[i]));714return *this;715}716717inline bool is_within_bounds(const vec& l, const vec& h) const718{719for (uint32_t i = 0; i < N; i++)720if ((m_s[i] < l[i]) || (m_s[i] > h[i]))721return false;722723return true;724}725726inline bool is_within_bounds(T l, T h) const727{728for (uint32_t i = 0; i < N; i++)729if ((m_s[i] < l) || (m_s[i] > h))730return false;731732return true;733}734735inline uint32_t get_major_axis(void) const736{737T m = fabs(m_s[0]);738uint32_t r = 0;739for (uint32_t i = 1; i < N; i++)740{741const T c = fabs(m_s[i]);742if (c > m)743{744m = c;745r = i;746}747}748return r;749}750751inline uint32_t get_minor_axis(void) const752{753T m = fabs(m_s[0]);754uint32_t r = 0;755for (uint32_t i = 1; i < N; i++)756{757const T c = fabs(m_s[i]);758if (c < m)759{760m = c;761r = i;762}763}764return r;765}766767inline void get_projection_axes(uint32_t& u, uint32_t& v) const768{769const int axis = get_major_axis();770if (m_s[axis] < 0.0f)771{772v = basisu::next_wrap<uint32_t>(axis, N);773u = basisu::next_wrap<uint32_t>(v, N);774}775else776{777u = basisu::next_wrap<uint32_t>(axis, N);778v = basisu::next_wrap<uint32_t>(u, N);779}780}781782inline T get_absolute_minimum(void) const783{784T result = fabs(m_s[0]);785for (uint32_t i = 1; i < N; i++)786result = basisu::minimum(result, fabs(m_s[i]));787return result;788}789790inline T get_absolute_maximum(void) const791{792T result = fabs(m_s[0]);793for (uint32_t i = 1; i < N; i++)794result = basisu::maximum(result, fabs(m_s[i]));795return result;796}797798inline T get_minimum(void) const799{800T result = m_s[0];801for (uint32_t i = 1; i < N; i++)802result = basisu::minimum(result, m_s[i]);803return result;804}805806inline T get_maximum(void) const807{808T result = m_s[0];809for (uint32_t i = 1; i < N; i++)810result = basisu::maximum(result, m_s[i]);811return result;812}813814inline vec& remove_unit_direction(const vec& dir)815{816*this -= (dot(dir) * dir);817return *this;818}819820inline vec get_remove_unit_direction(const vec& dir) const821{822return *this - (dot(dir) * dir);823}824825inline bool all_less(const vec& b) const826{827for (uint32_t i = 0; i < N; i++)828if (m_s[i] >= b.m_s[i])829return false;830return true;831}832833inline bool all_less_equal(const vec& b) const834{835for (uint32_t i = 0; i < N; i++)836if (m_s[i] > b.m_s[i])837return false;838return true;839}840841inline bool all_greater(const vec& b) const842{843for (uint32_t i = 0; i < N; i++)844if (m_s[i] <= b.m_s[i])845return false;846return true;847}848849inline bool all_greater_equal(const vec& b) const850{851for (uint32_t i = 0; i < N; i++)852if (m_s[i] < b.m_s[i])853return false;854return true;855}856857inline vec negate_xyz() const858{859vec ret;860861ret[0] = -m_s[0];862if (N >= 2)863ret[1] = -m_s[1];864if (N >= 3)865ret[2] = -m_s[2];866867for (uint32_t i = 3; i < N; i++)868ret[i] = m_s[i];869870return ret;871}872873inline vec& invert()874{875for (uint32_t i = 0; i < N; i++)876if (m_s[i] != 0.0f)877m_s[i] = 1.0f / m_s[i];878return *this;879}880881inline scalar_type perp_dot(const vec& b) const882{883static_assert(N == 2);884return m_s[0] * b.m_s[1] - m_s[1] * b.m_s[0];885}886887inline vec perp() const888{889static_assert(N == 2);890return vec(-m_s[1], m_s[0]);891}892893inline vec get_floor() const894{895vec result;896for (uint32_t i = 0; i < N; i++)897result[i] = floor(m_s[i]);898return result;899}900901inline vec get_ceil() const902{903vec result;904for (uint32_t i = 0; i < N; i++)905result[i] = ceil(m_s[i]);906return result;907}908909inline T get_total() const910{911T res = m_s[0];912for (uint32_t i = 1; i < N; i++)913res += m_s[i];914return res;915}916917// static helper methods918919static inline vec mul_components(const vec& lhs, const vec& rhs)920{921vec result;922for (uint32_t i = 0; i < N; i++)923result[i] = lhs.m_s[i] * rhs.m_s[i];924return result;925}926927static inline vec mul_add_components(const vec& a, const vec& b, const vec& c)928{929vec result;930for (uint32_t i = 0; i < N; i++)931result[i] = a.m_s[i] * b.m_s[i] + c.m_s[i];932return result;933}934935static inline vec make_axis(uint32_t i)936{937vec result;938result.clear();939result[i] = 1;940return result;941}942943static inline vec equals_mask(const vec& a, const vec& b)944{945vec ret;946for (uint32_t i = 0; i < N; i++)947ret[i] = (a[i] == b[i]);948return ret;949}950951static inline vec not_equals_mask(const vec& a, const vec& b)952{953vec ret;954for (uint32_t i = 0; i < N; i++)955ret[i] = (a[i] != b[i]);956return ret;957}958959static inline vec less_mask(const vec& a, const vec& b)960{961vec ret;962for (uint32_t i = 0; i < N; i++)963ret[i] = (a[i] < b[i]);964return ret;965}966967static inline vec less_equals_mask(const vec& a, const vec& b)968{969vec ret;970for (uint32_t i = 0; i < N; i++)971ret[i] = (a[i] <= b[i]);972return ret;973}974975static inline vec greater_equals_mask(const vec& a, const vec& b)976{977vec ret;978for (uint32_t i = 0; i < N; i++)979ret[i] = (a[i] >= b[i]);980return ret;981}982983static inline vec greater_mask(const vec& a, const vec& b)984{985vec ret;986for (uint32_t i = 0; i < N; i++)987ret[i] = (a[i] > b[i]);988return ret;989}990991static inline vec component_max(const vec& a, const vec& b)992{993vec ret;994for (uint32_t i = 0; i < N; i++)995ret.m_s[i] = basisu::maximum(a.m_s[i], b.m_s[i]);996return ret;997}998999static inline vec component_min(const vec& a, const vec& b)1000{1001vec ret;1002for (uint32_t i = 0; i < N; i++)1003ret.m_s[i] = basisu::minimum(a.m_s[i], b.m_s[i]);1004return ret;1005}10061007static inline vec lerp(const vec& a, const vec& b, float t)1008{1009vec ret;1010for (uint32_t i = 0; i < N; i++)1011ret.m_s[i] = a.m_s[i] + (b.m_s[i] - a.m_s[i]) * t;1012return ret;1013}10141015static inline bool equal_tol(const vec& a, const vec& b, float t)1016{1017for (uint32_t i = 0; i < N; i++)1018if (!basisu::equal_tol(a.m_s[i], b.m_s[i], t))1019return false;1020return true;1021}10221023inline bool equal_tol(const vec& b, float t) const1024{1025return equal_tol(*this, b, t);1026}10271028static inline vec make_random(basisu::rand& r, float l, float h)1029{1030vec result;1031for (uint32_t i = 0; i < N; i++)1032result[i] = r.frand(l, h);1033return result;1034}10351036static inline vec make_random(basisu::rand& r, const vec& l, const vec& h)1037{1038vec result;1039for (uint32_t i = 0; i < N; i++)1040result[i] = r.frand(l[i], h[i]);1041return result;1042}10431044void print() const1045{1046for (uint32_t c = 0; c < N; c++)1047printf("%3.3f ", (*this)[c]);1048printf("\n");1049}10501051protected:1052T m_s[N];1053};10541055typedef vec<1, double> vec1D;1056typedef vec<2, double> vec2D;1057typedef vec<3, double> vec3D;1058typedef vec<4, double> vec4D;10591060typedef vec<1, float> vec1F;10611062typedef vec<2, float> vec2F;1063typedef basisu::vector<vec2F> vec2F_array;10641065typedef vec<3, float> vec3F;1066typedef basisu::vector<vec3F> vec3F_array;10671068typedef vec<4, float> vec4F;1069typedef basisu::vector<vec4F> vec4F_array;10701071typedef vec<2, uint32_t> vec2U;1072typedef vec<3, uint32_t> vec3U;1073typedef vec<2, int> vec2I;1074typedef vec<3, int> vec3I;1075typedef vec<4, int> vec4I;10761077typedef vec<2, int16_t> vec2I16;1078typedef vec<3, int16_t> vec3I16;10791080inline vec2F rotate_point_2D(const vec2F& p, float rad)1081{1082float c = cosf(rad);1083float s = sinf(rad);10841085float x = p[0];1086float y = p[1];10871088return vec2F(x * c - y * s, x * s + y * c);1089}10901091//--------------------------------------------------------------10921093// Matrix/vector cheat sheet, because confusingly, depending on how matrices are stored in memory people can use opposite definitions of "rows", "cols", etc.1094// See http://www.mindcontrol.org/~hplus/graphics/matrix-layout.html1095//1096// So in this simple row-major general matrix class:1097// matrix=[NumRows][NumCols] or [R][C], i.e. a 3x3 matrix stored in memory will appear as: R0C0, R0C1, R0C2, R1C0, R1C1, R1C2, etc.1098// Matrix multiplication: [R0,C0]*[R1,C1]=[R0,C1], C0 must equal R11099//1100// In this class:1101// A "row vector" type is a vector of size # of matrix cols, 1xC. It's the vector type that is used to store the matrix rows.1102// A "col vector" type is a vector of size # of matrix rows, Rx1. It's a vector type large enough to hold each matrix column.1103//1104// Subrow/col vectors: last component is assumed to be either 0 (a "vector") or 1 (a "point")1105// "subrow vector": vector/point of size # cols-1, 1x(C-1)1106// "subcol vector": vector/point of size # rows-1, (R-1)x11107//1108// D3D style:1109// vec*matrix, row vector on left (vec dotted against columns)1110// [1,4]*[4,4]=[1,4]1111// abcd * A B C D1112// A B C D1113// A B C D1114// A B C D1115// = e f g h1116//1117// Now confusingly, in the matrix transform method for vec*matrix below the vector's type is "col_vec", because col_vec will have the proper size for non-square matrices. But the vector on the left is written as row vector, argh.1118//1119//1120// OGL style:1121// matrix*vec, col vector on right (vec dotted against rows):1122// [4,4]*[4,1]=[4,1]1123//1124// A B C D * e = e1125// A B C D f f1126// A B C D g g1127// A B C D h h11281129template <class X, class Y, class Z>1130Z& matrix_mul_helper(Z& result, const X& lhs, const Y& rhs)1131{1132static_assert((int)Z::num_rows == (int)X::num_rows);1133static_assert((int)Z::num_cols == (int)Y::num_cols);1134static_assert((int)X::num_cols == (int)Y::num_rows);1135assert(((void*)&result != (void*)&lhs) && ((void*)&result != (void*)&rhs));1136for (int r = 0; r < X::num_rows; r++)1137for (int c = 0; c < Y::num_cols; c++)1138{1139typename Z::scalar_type s = lhs(r, 0) * rhs(0, c);1140for (uint32_t i = 1; i < X::num_cols; i++)1141s += lhs(r, i) * rhs(i, c);1142result(r, c) = s;1143}1144return result;1145}11461147template <class X, class Y, class Z>1148Z& matrix_mul_helper_transpose_lhs(Z& result, const X& lhs, const Y& rhs)1149{1150static_assert((int)Z::num_rows == (int)X::num_cols);1151static_assert((int)Z::num_cols == (int)Y::num_cols);1152static_assert((int)X::num_rows == (int)Y::num_rows);1153assert(((void*)&result != (void*)&lhs) && ((void*)&result != (void*)&rhs));1154for (int r = 0; r < X::num_cols; r++)1155for (int c = 0; c < Y::num_cols; c++)1156{1157typename Z::scalar_type s = lhs(0, r) * rhs(0, c);1158for (uint32_t i = 1; i < X::num_rows; i++)1159s += lhs(i, r) * rhs(i, c);1160result(r, c) = s;1161}1162return result;1163}11641165template <class X, class Y, class Z>1166Z& matrix_mul_helper_transpose_rhs(Z& result, const X& lhs, const Y& rhs)1167{1168static_assert((int)Z::num_rows == (int)X::num_rows);1169static_assert((int)Z::num_cols == (int)Y::num_rows);1170static_assert((int)X::num_cols == (int)Y::num_cols);1171assert(((void*)&result != (void*)&lhs) && ((void*)&result != (void*)&rhs));1172for (int r = 0; r < X::num_rows; r++)1173for (int c = 0; c < Y::num_rows; c++)1174{1175typename Z::scalar_type s = lhs(r, 0) * rhs(c, 0);1176for (uint32_t i = 1; i < X::num_cols; i++)1177s += lhs(r, i) * rhs(c, i);1178result(r, c) = s;1179}1180return result;1181}11821183template <uint32_t R, uint32_t C, typename T>1184class matrix1185{1186public:1187typedef T scalar_type;1188enum1189{1190num_rows = R,1191num_cols = C1192};11931194typedef vec<R, T> col_vec;1195typedef vec < (R > 1) ? (R - 1) : 0, T > subcol_vec;11961197typedef vec<C, T> row_vec;1198typedef vec < (C > 1) ? (C - 1) : 0, T > subrow_vec;11991200inline matrix()1201{1202}12031204inline matrix(basisu::eClear)1205{1206clear();1207}12081209inline matrix(basisu::eIdentity)1210{1211set_identity_matrix();1212}12131214inline matrix(const T* p)1215{1216set(p);1217}12181219inline matrix(const matrix& other)1220{1221for (uint32_t i = 0; i < R; i++)1222m_rows[i] = other.m_rows[i];1223}12241225inline matrix& operator=(const matrix& rhs)1226{1227if (this != &rhs)1228for (uint32_t i = 0; i < R; i++)1229m_rows[i] = rhs.m_rows[i];1230return *this;1231}12321233inline matrix(T val00, T val01,1234T val10, T val11)1235{1236set(val00, val01, val10, val11);1237}12381239inline matrix(T val00, T val01,1240T val10, T val11,1241T val20, T val21)1242{1243set(val00, val01, val10, val11, val20, val21);1244}12451246inline matrix(T val00, T val01, T val02,1247T val10, T val11, T val12,1248T val20, T val21, T val22)1249{1250set(val00, val01, val02, val10, val11, val12, val20, val21, val22);1251}12521253inline matrix(T val00, T val01, T val02, T val03,1254T val10, T val11, T val12, T val13,1255T val20, T val21, T val22, T val23,1256T val30, T val31, T val32, T val33)1257{1258set(val00, val01, val02, val03, val10, val11, val12, val13, val20, val21, val22, val23, val30, val31, val32, val33);1259}12601261inline matrix(T val00, T val01, T val02, T val03,1262T val10, T val11, T val12, T val13,1263T val20, T val21, T val22, T val23)1264{1265set(val00, val01, val02, val03, val10, val11, val12, val13, val20, val21, val22, val23);1266}12671268inline void set(const float* p)1269{1270for (uint32_t i = 0; i < R; i++)1271{1272m_rows[i].set(p);1273p += C;1274}1275}12761277inline void set(T val00, T val01,1278T val10, T val11)1279{1280m_rows[0].set(val00, val01);1281if (R >= 2)1282{1283m_rows[1].set(val10, val11);12841285for (uint32_t i = 2; i < R; i++)1286m_rows[i].clear();1287}1288}12891290inline void set(T val00, T val01,1291T val10, T val11,1292T val20, T val21)1293{1294m_rows[0].set(val00, val01);1295if (R >= 2)1296{1297m_rows[1].set(val10, val11);12981299if (R >= 3)1300{1301m_rows[2].set(val20, val21);13021303for (uint32_t i = 3; i < R; i++)1304m_rows[i].clear();1305}1306}1307}13081309inline void set(T val00, T val01, T val02,1310T val10, T val11, T val12,1311T val20, T val21, T val22)1312{1313m_rows[0].set(val00, val01, val02);1314if (R >= 2)1315{1316m_rows[1].set(val10, val11, val12);1317if (R >= 3)1318{1319m_rows[2].set(val20, val21, val22);13201321for (uint32_t i = 3; i < R; i++)1322m_rows[i].clear();1323}1324}1325}13261327inline void set(T val00, T val01, T val02, T val03,1328T val10, T val11, T val12, T val13,1329T val20, T val21, T val22, T val23,1330T val30, T val31, T val32, T val33)1331{1332m_rows[0].set(val00, val01, val02, val03);1333if (R >= 2)1334{1335m_rows[1].set(val10, val11, val12, val13);1336if (R >= 3)1337{1338m_rows[2].set(val20, val21, val22, val23);13391340if (R >= 4)1341{1342m_rows[3].set(val30, val31, val32, val33);13431344for (uint32_t i = 4; i < R; i++)1345m_rows[i].clear();1346}1347}1348}1349}13501351inline void set(T val00, T val01, T val02, T val03,1352T val10, T val11, T val12, T val13,1353T val20, T val21, T val22, T val23)1354{1355m_rows[0].set(val00, val01, val02, val03);1356if (R >= 2)1357{1358m_rows[1].set(val10, val11, val12, val13);1359if (R >= 3)1360{1361m_rows[2].set(val20, val21, val22, val23);13621363for (uint32_t i = 3; i < R; i++)1364m_rows[i].clear();1365}1366}1367}13681369inline uint32_t get_num_rows() const1370{1371return num_rows;1372}13731374inline uint32_t get_num_cols() const1375{1376return num_cols;1377}13781379inline uint32_t get_total_elements() const1380{1381return num_rows * num_cols;1382}13831384inline T operator()(uint32_t r, uint32_t c) const1385{1386assert((r < R) && (c < C));1387return m_rows[r][c];1388}13891390inline T& operator()(uint32_t r, uint32_t c)1391{1392assert((r < R) && (c < C));1393return m_rows[r][c];1394}13951396inline const row_vec& operator[](uint32_t r) const1397{1398assert(r < R);1399return m_rows[r];1400}14011402inline row_vec& operator[](uint32_t r)1403{1404assert(r < R);1405return m_rows[r];1406}14071408inline const row_vec& get_row(uint32_t r) const1409{1410return (*this)[r];1411}14121413inline row_vec& get_row(uint32_t r)1414{1415return (*this)[r];1416}14171418inline void set_row(uint32_t r, const row_vec& v)1419{1420(*this)[r] = v;1421}14221423inline col_vec get_col(uint32_t c) const1424{1425assert(c < C);1426col_vec result;1427for (uint32_t i = 0; i < R; i++)1428result[i] = m_rows[i][c];1429return result;1430}14311432inline void set_col(uint32_t c, const col_vec& col)1433{1434assert(c < C);1435for (uint32_t i = 0; i < R; i++)1436m_rows[i][c] = col[i];1437}14381439inline void set_col(uint32_t c, const subcol_vec& col)1440{1441assert(c < C);1442for (uint32_t i = 0; i < (R - 1); i++)1443m_rows[i][c] = col[i];14441445m_rows[R - 1][c] = 0.0f;1446}14471448inline const row_vec& get_translate() const1449{1450return m_rows[R - 1];1451}14521453inline matrix& set_translate(const row_vec& r)1454{1455m_rows[R - 1] = r;1456return *this;1457}14581459inline matrix& set_translate(const subrow_vec& r)1460{1461m_rows[R - 1] = row_vec(r).as_point();1462return *this;1463}14641465inline const T* get_ptr() const1466{1467return reinterpret_cast<const T*>(&m_rows[0]);1468}1469inline T* get_ptr()1470{1471return reinterpret_cast<T*>(&m_rows[0]);1472}14731474inline matrix& operator+=(const matrix& other)1475{1476for (uint32_t i = 0; i < R; i++)1477m_rows[i] += other.m_rows[i];1478return *this;1479}14801481inline matrix& operator-=(const matrix& other)1482{1483for (uint32_t i = 0; i < R; i++)1484m_rows[i] -= other.m_rows[i];1485return *this;1486}14871488inline matrix& operator*=(T val)1489{1490for (uint32_t i = 0; i < R; i++)1491m_rows[i] *= val;1492return *this;1493}14941495inline matrix& operator/=(T val)1496{1497for (uint32_t i = 0; i < R; i++)1498m_rows[i] /= val;1499return *this;1500}15011502inline matrix& operator*=(const matrix& other)1503{1504matrix result;1505matrix_mul_helper(result, *this, other);1506*this = result;1507return *this;1508}15091510friend inline matrix operator+(const matrix& lhs, const matrix& rhs)1511{1512matrix result;1513for (uint32_t i = 0; i < R; i++)1514result[i] = lhs.m_rows[i] + rhs.m_rows[i];1515return result;1516}15171518friend inline matrix operator-(const matrix& lhs, const matrix& rhs)1519{1520matrix result;1521for (uint32_t i = 0; i < R; i++)1522result[i] = lhs.m_rows[i] - rhs.m_rows[i];1523return result;1524}15251526friend inline matrix operator*(const matrix& lhs, T val)1527{1528matrix result;1529for (uint32_t i = 0; i < R; i++)1530result[i] = lhs.m_rows[i] * val;1531return result;1532}15331534friend inline matrix operator/(const matrix& lhs, T val)1535{1536matrix result;1537for (uint32_t i = 0; i < R; i++)1538result[i] = lhs.m_rows[i] / val;1539return result;1540}15411542friend inline matrix operator*(T val, const matrix& rhs)1543{1544matrix result;1545for (uint32_t i = 0; i < R; i++)1546result[i] = val * rhs.m_rows[i];1547return result;1548}15491550#if 01551template<uint32_t R0, uint32_t C0, uint32_t R1, uint32_t C1, typename T>1552friend inline matrix operator*(const matrix<R0, C0, T>& lhs, const matrix<R1, C1, T>& rhs)1553{1554matrix<R0, C1, T> result;1555return matrix_mul_helper(result, lhs, rhs);1556}1557#endif1558friend inline matrix operator*(const matrix& lhs, const matrix& rhs)1559{1560matrix result;1561return matrix_mul_helper(result, lhs, rhs);1562}15631564friend inline row_vec operator*(const col_vec& a, const matrix& b)1565{1566return transform(a, b);1567}15681569inline matrix operator+() const1570{1571return *this;1572}15731574inline matrix operator-() const1575{1576matrix result;1577for (uint32_t i = 0; i < R; i++)1578result[i] = -m_rows[i];1579return result;1580}15811582inline matrix& clear()1583{1584for (uint32_t i = 0; i < R; i++)1585m_rows[i].clear();1586return *this;1587}15881589inline matrix& set_zero_matrix()1590{1591clear();1592return *this;1593}15941595inline matrix& set_identity_matrix()1596{1597for (uint32_t i = 0; i < R; i++)1598{1599m_rows[i].clear();1600m_rows[i][i] = 1.0f;1601}1602return *this;1603}16041605inline matrix& set_scale_matrix(float s)1606{1607clear();1608for (int i = 0; i < (R - 1); i++)1609m_rows[i][i] = s;1610m_rows[R - 1][C - 1] = 1.0f;1611return *this;1612}16131614inline matrix& set_scale_matrix(const row_vec& s)1615{1616clear();1617for (uint32_t i = 0; i < R; i++)1618m_rows[i][i] = s[i];1619return *this;1620}16211622inline matrix& set_scale_matrix(float x, float y)1623{1624set_identity_matrix();1625m_rows[0].set_x(x);1626m_rows[1].set_y(y);1627return *this;1628}16291630inline matrix& set_scale_matrix(float x, float y, float z)1631{1632set_identity_matrix();1633m_rows[0].set_x(x);1634m_rows[1].set_y(y);1635m_rows[2].set_z(z);1636return *this;1637}16381639inline matrix& set_translate_matrix(const row_vec& s)1640{1641set_identity_matrix();1642set_translate(s);1643return *this;1644}16451646inline matrix& set_translate_matrix(float x, float y)1647{1648set_identity_matrix();1649set_translate(row_vec(x, y).as_point());1650return *this;1651}16521653inline matrix& set_translate_matrix(float x, float y, float z)1654{1655set_identity_matrix();1656set_translate(row_vec(x, y, z).as_point());1657return *this;1658}16591660inline matrix get_transposed() const1661{1662static_assert(R == C);16631664matrix result;1665for (uint32_t i = 0; i < R; i++)1666for (uint32_t j = 0; j < C; j++)1667result.m_rows[i][j] = m_rows[j][i];1668return result;1669}16701671inline matrix<C, R, T> get_transposed_nonsquare() const1672{1673matrix<C, R, T> result;1674for (uint32_t i = 0; i < R; i++)1675for (uint32_t j = 0; j < C; j++)1676result[j][i] = m_rows[i][j];1677return result;1678}16791680inline matrix& transpose_in_place()1681{1682matrix result;1683for (uint32_t i = 0; i < R; i++)1684for (uint32_t j = 0; j < C; j++)1685result.m_rows[i][j] = m_rows[j][i];1686*this = result;1687return *this;1688}16891690// Frobenius Norm1691T get_norm() const1692{1693T result = 0;16941695for (uint32_t i = 0; i < R; i++)1696for (uint32_t j = 0; j < C; j++)1697result += m_rows[i][j] * m_rows[i][j];16981699return static_cast<T>(sqrt(result));1700}17011702inline matrix get_power(T p) const1703{1704matrix result;17051706for (uint32_t i = 0; i < R; i++)1707for (uint32_t j = 0; j < C; j++)1708result[i][j] = static_cast<T>(pow(m_rows[i][j], p));17091710return result;1711}17121713inline matrix<1, R, T> numpy_dot(const matrix<1, C, T>& b) const1714{1715matrix<1, R, T> result;17161717for (uint32_t r = 0; r < R; r++)1718{1719T sum = 0;1720for (uint32_t c = 0; c < C; c++)1721sum += m_rows[r][c] * b[0][c];17221723result[0][r] = static_cast<T>(sum);1724}17251726return result;1727}17281729bool invert(matrix& result) const1730{1731static_assert(R == C);17321733result.set_identity_matrix();17341735matrix mat(*this);17361737for (uint32_t c = 0; c < C; c++)1738{1739uint32_t max_r = c;1740for (uint32_t r = c + 1; r < R; r++)1741if (fabs(mat[r][c]) > fabs(mat[max_r][c]))1742max_r = r;17431744if (mat[max_r][c] == 0.0f)1745{1746result.set_identity_matrix();1747return false;1748}17491750std::swap(mat[c], mat[max_r]);1751std::swap(result[c], result[max_r]);17521753result[c] /= mat[c][c];1754mat[c] /= mat[c][c];17551756for (uint32_t row = 0; row < R; row++)1757{1758if (row != c)1759{1760const row_vec temp(mat[row][c]);1761mat[row] -= row_vec::mul_components(mat[c], temp);1762result[row] -= row_vec::mul_components(result[c], temp);1763}1764}1765}17661767return true;1768}17691770matrix& invert_in_place()1771{1772matrix result;1773invert(result);1774*this = result;1775return *this;1776}17771778matrix get_inverse() const1779{1780matrix result;1781invert(result);1782return result;1783}17841785T get_det() const1786{1787static_assert(R == C);1788return det_helper(*this, R);1789}17901791bool equal_tol(const matrix& b, float tol) const1792{1793for (uint32_t r = 0; r < R; r++)1794if (!row_vec::equal_tol(m_rows[r], b.m_rows[r], tol))1795return false;1796return true;1797}17981799bool is_square() const1800{1801return R == C;1802}18031804double get_trace() const1805{1806static_assert(is_square());18071808T total = 0;1809for (uint32_t i = 0; i < R; i++)1810total += (*this)(i, i);18111812return total;1813}18141815void print() const1816{1817for (uint32_t r = 0; r < R; r++)1818{1819for (uint32_t c = 0; c < C; c++)1820printf("%3.7f ", (*this)(r, c));1821printf("\n");1822}1823}18241825// This method transforms a vec by a matrix (D3D-style: row vector on left).1826// Confusingly, note that the data type is named "col_vec", but mathematically it's actually written as a row vector (of size equal to the # matrix rows, which is why it's called a "col_vec" in this class).1827// 1xR * RxC = 1xC1828// This dots against the matrix columns.1829static inline row_vec transform(const col_vec& a, const matrix& b)1830{1831row_vec result(b[0] * a[0]);1832for (uint32_t r = 1; r < R; r++)1833result += b[r] * a[r];1834return result;1835}18361837// This method transforms a vec by a matrix (D3D-style: row vector on left).1838// Last component of vec is assumed to be 1.1839static inline row_vec transform_point(const col_vec& a, const matrix& b)1840{1841row_vec result(0);1842for (int r = 0; r < (R - 1); r++)1843result += b[r] * a[r];1844result += b[R - 1];1845return result;1846}18471848// This method transforms a vec by a matrix (D3D-style: row vector on left).1849// Last component of vec is assumed to be 0.1850static inline row_vec transform_vector(const col_vec& a, const matrix& b)1851{1852row_vec result(0);1853for (int r = 0; r < (R - 1); r++)1854result += b[r] * a[r];1855return result;1856}18571858// This method transforms a vec by a matrix (D3D-style: row vector on left).1859// Last component of vec is assumed to be 1.1860static inline subcol_vec transform_point(const subcol_vec& a, const matrix& b)1861{1862subcol_vec result(0);1863for (int r = 0; r < static_cast<int>(R); r++)1864{1865const T s = (r < subcol_vec::num_elements) ? a[r] : 1.0f;1866for (int c = 0; c < static_cast<int>(C - 1); c++)1867result[c] += b[r][c] * s;1868}1869return result;1870}18711872// This method transforms a vec by a matrix (D3D-style: row vector on left).1873// Last component of vec is assumed to be 0.1874static inline subcol_vec transform_vector(const subcol_vec& a, const matrix& b)1875{1876subcol_vec result(0);1877for (int r = 0; r < static_cast<int>(R - 1); r++)1878{1879const T s = a[r];1880for (int c = 0; c < static_cast<int>(C - 1); c++)1881result[c] += b[r][c] * s;1882}1883return result;1884}18851886// Like transform() above, but the matrix is effectively transposed before the multiply.1887static inline col_vec transform_transposed(const col_vec& a, const matrix& b)1888{1889static_assert(R == C);1890col_vec result;1891for (uint32_t r = 0; r < R; r++)1892result[r] = b[r].dot(a);1893return result;1894}18951896// Like transform() above, but the matrix is effectively transposed before the multiply.1897// Last component of vec is assumed to be 0.1898static inline col_vec transform_vector_transposed(const col_vec& a, const matrix& b)1899{1900static_assert(R == C);1901col_vec result;1902for (uint32_t r = 0; r < R; r++)1903{1904T s = 0;1905for (uint32_t c = 0; c < (C - 1); c++)1906s += b[r][c] * a[c];19071908result[r] = s;1909}1910return result;1911}19121913// This method transforms a vec by a matrix (D3D-style: row vector on left), but the matrix is effectively transposed before the multiply.1914// Last component of vec is assumed to be 1.1915static inline subcol_vec transform_point_transposed(const subcol_vec& a, const matrix& b)1916{1917static_assert(R == C);1918subcol_vec result(0);1919for (int r = 0; r < R; r++)1920{1921const T s = (r < subcol_vec::num_elements) ? a[r] : 1.0f;1922for (int c = 0; c < (C - 1); c++)1923result[c] += b[c][r] * s;1924}1925return result;1926}19271928// This method transforms a vec by a matrix (D3D-style: row vector on left), but the matrix is effectively transposed before the multiply.1929// Last component of vec is assumed to be 0.1930static inline subcol_vec transform_vector_transposed(const subcol_vec& a, const matrix& b)1931{1932static_assert(R == C);1933subcol_vec result(0);1934for (int r = 0; r < static_cast<int>(R - 1); r++)1935{1936const T s = a[r];1937for (int c = 0; c < static_cast<int>(C - 1); c++)1938result[c] += b[c][r] * s;1939}1940return result;1941}19421943// This method transforms a matrix by a vector (OGL style, col vector on the right).1944// Note that the data type is named "row_vec", but mathematically it's actually written as a column vector (of size equal to the # matrix cols).1945// RxC * Cx1 = Rx11946// This dots against the matrix rows.1947static inline col_vec transform(const matrix& b, const row_vec& a)1948{1949col_vec result;1950for (int r = 0; r < static_cast<int>(R); r++)1951result[r] = b[r].dot(a);1952return result;1953}19541955// This method transforms a matrix by a vector (OGL style, col vector on the right), except the matrix is effectively transposed before the multiply.1956// Note that the data type is named "row_vec", but mathematically it's actually written as a column vector (of size equal to the # matrix cols).1957// RxC * Cx1 = Rx11958// This dots against the matrix cols.1959static inline col_vec transform_transposed(const matrix& b, const row_vec& a)1960{1961static_assert(R == C);1962row_vec result(b[0] * a[0]);1963for (int r = 1; r < static_cast<int>(R); r++)1964result += b[r] * a[r];1965return col_vec(result);1966}19671968static inline matrix& mul_components(matrix& result, const matrix& lhs, const matrix& rhs)1969{1970for (uint32_t r = 0; r < R; r++)1971result[r] = row_vec::mul_components(lhs[r], rhs[r]);1972return result;1973}19741975static inline matrix& concat(matrix& lhs, const matrix& rhs)1976{1977return matrix_mul_helper(lhs, matrix(lhs), rhs);1978}19791980inline matrix& concat_in_place(const matrix& rhs)1981{1982return concat(*this, rhs);1983}19841985static inline matrix& multiply(matrix& result, const matrix& lhs, const matrix& rhs)1986{1987matrix temp;1988matrix* pResult = ((&result == &lhs) || (&result == &rhs)) ? &temp : &result;19891990matrix_mul_helper(*pResult, lhs, rhs);1991if (pResult != &result)1992result = *pResult;19931994return result;1995}19961997static matrix make_zero_matrix()1998{1999matrix result;2000result.clear();2001return result;2002}20032004static matrix make_identity_matrix()2005{2006matrix result;2007result.set_identity_matrix();2008return result;2009}20102011static matrix make_translate_matrix(const row_vec& t)2012{2013return matrix(basisu::cIdentity).set_translate(t);2014}20152016static matrix make_translate_matrix(float x, float y)2017{2018return matrix(basisu::cIdentity).set_translate_matrix(x, y);2019}20202021static matrix make_translate_matrix(float x, float y, float z)2022{2023return matrix(basisu::cIdentity).set_translate_matrix(x, y, z);2024}20252026static inline matrix make_scale_matrix(float s)2027{2028return matrix().set_scale_matrix(s);2029}20302031static inline matrix make_scale_matrix(const row_vec& s)2032{2033return matrix().set_scale_matrix(s);2034}20352036static inline matrix make_scale_matrix(float x, float y)2037{2038static_assert(R >= 3 && C >= 3);2039matrix result;2040result.set_identity_matrix();2041result.m_rows[0][0] = x;2042result.m_rows[1][1] = y;2043return result;2044}20452046static inline matrix make_scale_matrix(float x, float y, float z)2047{2048static_assert(R >= 4 && C >= 4);2049matrix result;2050result.set_identity_matrix();2051result.m_rows[0][0] = x;2052result.m_rows[1][1] = y;2053result.m_rows[2][2] = z;2054return result;2055}20562057// Helpers derived from Graphics Gems 1 and 2 (Matrices and Transformations, Ronald N. Goldman)2058static matrix make_rotate_matrix(const vec<3, T>& axis, T ang)2059{2060static_assert(R >= 3 && C >= 3);20612062vec<3, T> norm_axis(axis.get_normalized());20632064double cos_a = cos(ang);2065double inv_cos_a = 1.0f - cos_a;20662067double sin_a = sin(ang);20682069const T x = norm_axis[0];2070const T y = norm_axis[1];2071const T z = norm_axis[2];20722073const double x2 = norm_axis[0] * norm_axis[0];2074const double y2 = norm_axis[1] * norm_axis[1];2075const double z2 = norm_axis[2] * norm_axis[2];20762077matrix result;2078result.set_identity_matrix();20792080result[0][0] = (T)((inv_cos_a * x2) + cos_a);2081result[1][0] = (T)((inv_cos_a * x * y) + (sin_a * z));2082result[2][0] = (T)((inv_cos_a * x * z) - (sin_a * y));20832084result[0][1] = (T)((inv_cos_a * x * y) - (sin_a * z));2085result[1][1] = (T)((inv_cos_a * y2) + cos_a);2086result[2][1] = (T)((inv_cos_a * y * z) + (sin_a * x));20872088result[0][2] = (T)((inv_cos_a * x * z) + (sin_a * y));2089result[1][2] = (T)((inv_cos_a * y * z) - (sin_a * x));2090result[2][2] = (T)((inv_cos_a * z2) + cos_a);20912092return result;2093}20942095static inline matrix make_rotate_matrix(T ang)2096{2097static_assert(R >= 2 && C >= 2);20982099matrix ret(basisu::cIdentity);21002101const T sin_a = static_cast<T>(sin(ang));2102const T cos_a = static_cast<T>(cos(ang));21032104ret[0][0] = +cos_a;2105ret[0][1] = -sin_a;2106ret[1][0] = +sin_a;2107ret[1][1] = +cos_a;21082109return ret;2110}21112112static inline matrix make_rotate_matrix(uint32_t axis, T ang)2113{2114vec<3, T> axis_vec;2115axis_vec.clear();2116axis_vec[axis] = 1.0f;2117return make_rotate_matrix(axis_vec, ang);2118}21192120static inline matrix make_cross_product_matrix(const vec<3, scalar_type>& c)2121{2122static_assert((num_rows >= 3) && (num_cols >= 3));2123matrix ret(basisu::cClear);2124ret[0][1] = c[2];2125ret[0][2] = -c[1];2126ret[1][0] = -c[2];2127ret[1][2] = c[0];2128ret[2][0] = c[1];2129ret[2][1] = -c[0];2130return ret;2131}21322133static inline matrix make_reflection_matrix(const vec<4, scalar_type>& n, const vec<4, scalar_type>& q)2134{2135static_assert((num_rows == 4) && (num_cols == 4));2136matrix ret;2137assert(n.is_vector() && q.is_vector());2138ret = make_identity_matrix() - 2.0f * make_tensor_product_matrix(n, n);2139ret.set_translate((2.0f * q.dot(n) * n).as_point());2140return ret;2141}21422143static inline matrix make_tensor_product_matrix(const row_vec& v, const row_vec& w)2144{2145matrix ret;2146for (int r = 0; r < num_rows; r++)2147ret[r] = row_vec::mul_components(v.broadcast(r), w);2148return ret;2149}21502151static inline matrix make_uniform_scaling_matrix(const vec<4, scalar_type>& q, scalar_type c)2152{2153static_assert((num_rows == 4) && (num_cols == 4));2154assert(q.is_vector());2155matrix ret;2156ret = c * make_identity_matrix();2157ret.set_translate(((1.0f - c) * q).as_point());2158return ret;2159}21602161static inline matrix make_nonuniform_scaling_matrix(const vec<4, scalar_type>& q, scalar_type c, const vec<4, scalar_type>& w)2162{2163static_assert((num_rows == 4) && (num_cols == 4));2164assert(q.is_vector() && w.is_vector());2165matrix ret;2166ret = make_identity_matrix() - (1.0f - c) * make_tensor_product_matrix(w, w);2167ret.set_translate(((1.0f - c) * q.dot(w) * w).as_point());2168return ret;2169}21702171// n = normal of plane, q = point on plane2172static inline matrix make_ortho_projection_matrix(const vec<4, scalar_type>& n, const vec<4, scalar_type>& q)2173{2174assert(n.is_vector() && q.is_vector());2175matrix ret;2176ret = make_identity_matrix() - make_tensor_product_matrix(n, n);2177ret.set_translate((q.dot(n) * n).as_point());2178return ret;2179}21802181static inline matrix make_parallel_projection(const vec<4, scalar_type>& n, const vec<4, scalar_type>& q, const vec<4, scalar_type>& w)2182{2183assert(n.is_vector() && q.is_vector() && w.is_vector());2184matrix ret;2185ret = make_identity_matrix() - (make_tensor_product_matrix(n, w) / (w.dot(n)));2186ret.set_translate(((q.dot(n) / w.dot(n)) * w).as_point());2187return ret;2188}21892190protected:2191row_vec m_rows[R];21922193static T det_helper(const matrix& a, uint32_t n)2194{2195// Algorithm ported from Numerical Recipes in C.2196T d;2197matrix m;2198if (n == 2)2199d = a(0, 0) * a(1, 1) - a(1, 0) * a(0, 1);2200else2201{2202d = 0;2203for (uint32_t j1 = 1; j1 <= n; j1++)2204{2205for (uint32_t i = 2; i <= n; i++)2206{2207int j2 = 1;2208for (uint32_t j = 1; j <= n; j++)2209{2210if (j != j1)2211{2212m(i - 2, j2 - 1) = a(i - 1, j - 1);2213j2++;2214}2215}2216}2217d += (((1 + j1) & 1) ? -1.0f : 1.0f) * a(1 - 1, j1 - 1) * det_helper(m, n - 1);2218}2219}2220return d;2221}2222};22232224typedef matrix<2, 2, float> matrix22F;2225typedef matrix<2, 2, double> matrix22D;22262227typedef matrix<3, 3, float> matrix33F;2228typedef matrix<3, 3, double> matrix33D;22292230typedef matrix<4, 4, float> matrix44F;2231typedef matrix<4, 4, double> matrix44D;22322233typedef matrix<8, 8, float> matrix88F;22342235// These helpers create good old D3D-style matrices.2236inline matrix44F matrix44F_make_perspective_offcenter_lh(float l, float r, float b, float t, float nz, float fz)2237{2238float two_nz = 2.0f * nz;2239float one_over_width = 1.0f / (r - l);2240float one_over_height = 1.0f / (t - b);22412242matrix44F view_to_proj;2243view_to_proj[0].set(two_nz * one_over_width, 0.0f, 0.0f, 0.0f);2244view_to_proj[1].set(0.0f, two_nz * one_over_height, 0.0f, 0.0f);2245view_to_proj[2].set(-(l + r) * one_over_width, -(t + b) * one_over_height, fz / (fz - nz), 1.0f);2246view_to_proj[3].set(0.0f, 0.0f, -view_to_proj[2][2] * nz, 0.0f);2247return view_to_proj;2248}22492250// fov_y: full Y field of view (radians)2251// aspect: viewspace width/height2252inline matrix44F matrix44F_make_perspective_fov_lh(float fov_y, float aspect, float nz, float fz)2253{2254double sin_fov = sin(0.5f * fov_y);2255double cos_fov = cos(0.5f * fov_y);22562257float y_scale = static_cast<float>(cos_fov / sin_fov);2258float x_scale = static_cast<float>(y_scale / aspect);22592260matrix44F view_to_proj;2261view_to_proj[0].set(x_scale, 0, 0, 0);2262view_to_proj[1].set(0, y_scale, 0, 0);2263view_to_proj[2].set(0, 0, fz / (fz - nz), 1);2264view_to_proj[3].set(0, 0, -nz * fz / (fz - nz), 0);2265return view_to_proj;2266}22672268inline matrix44F matrix44F_make_ortho_offcenter_lh(float l, float r, float b, float t, float nz, float fz)2269{2270matrix44F view_to_proj;2271view_to_proj[0].set(2.0f / (r - l), 0.0f, 0.0f, 0.0f);2272view_to_proj[1].set(0.0f, 2.0f / (t - b), 0.0f, 0.0f);2273view_to_proj[2].set(0.0f, 0.0f, 1.0f / (fz - nz), 0.0f);2274view_to_proj[3].set((l + r) / (l - r), (t + b) / (b - t), nz / (nz - fz), 1.0f);2275return view_to_proj;2276}22772278inline matrix44F matrix44F_make_ortho_lh(float w, float h, float nz, float fz)2279{2280return matrix44F_make_ortho_offcenter_lh(-w * .5f, w * .5f, -h * .5f, h * .5f, nz, fz);2281}22822283inline matrix44F matrix44F_make_projection_to_screen_d3d(int x, int y, int w, int h, float min_z, float max_z)2284{2285matrix44F proj_to_screen;2286proj_to_screen[0].set(w * .5f, 0.0f, 0.0f, 0.0f);2287proj_to_screen[1].set(0, h * -.5f, 0.0f, 0.0f);2288proj_to_screen[2].set(0, 0.0f, max_z - min_z, 0.0f);2289proj_to_screen[3].set(x + w * .5f, y + h * .5f, min_z, 1.0f);2290return proj_to_screen;2291}22922293inline matrix44F matrix44F_make_lookat_lh(const vec3F& camera_pos, const vec3F& look_at, const vec3F& camera_up, float camera_roll_ang_in_radians)2294{2295vec4F col2(look_at - camera_pos);2296assert(col2.is_vector());2297if (col2.normalize() == 0.0f)2298col2.set(0, 0, 1, 0);22992300vec4F col1(camera_up);2301assert(col1.is_vector());2302if (!col2[0] && !col2[2])2303col1.set(-1.0f, 0.0f, 0.0f, 0.0f);23042305if ((col1.dot(col2)) > .9999f)2306col1.set(0.0f, 1.0f, 0.0f, 0.0f);23072308vec4F col0(vec4F::cross3(col1, col2).normalize_in_place());2309col1 = vec4F::cross3(col2, col0).normalize_in_place();23102311matrix44F rotm(matrix44F::make_identity_matrix());2312rotm.set_col(0, col0);2313rotm.set_col(1, col1);2314rotm.set_col(2, col2);2315return matrix44F::make_translate_matrix(-camera_pos[0], -camera_pos[1], -camera_pos[2]) * rotm * matrix44F::make_rotate_matrix(2, camera_roll_ang_in_radians);2316}23172318template<typename R> R matrix_NxN_create_DCT()2319{2320assert(R::num_rows == R::num_cols);23212322const uint32_t N = R::num_cols;23232324R result;2325for (uint32_t k = 0; k < N; k++)2326{2327for (uint32_t n = 0; n < N; n++)2328{2329double f;23302331if (!k)2332f = 1.0f / sqrt(float(N));2333else2334f = sqrt(2.0f / float(N)) * cos((basisu::cPiD * (2.0f * float(n) + 1.0f) * float(k)) / (2.0f * float(N)));23352336result(k, n) = static_cast<typename R::scalar_type>(f);2337}2338}23392340return result;2341}23422343template<typename R> R matrix_NxN_DCT(const R& a, const R& dct)2344{2345R temp;2346matrix_mul_helper<R, R, R>(temp, dct, a);2347R result;2348matrix_mul_helper_transpose_rhs<R, R, R>(result, temp, dct);2349return result;2350}23512352template<typename R> R matrix_NxN_IDCT(const R& b, const R& dct)2353{2354R temp;2355matrix_mul_helper_transpose_lhs<R, R, R>(temp, dct, b);2356R result;2357matrix_mul_helper<R, R, R>(result, temp, dct);2358return result;2359}23602361template<typename X, typename Y> matrix<X::num_rows* Y::num_rows, X::num_cols* Y::num_cols, typename X::scalar_type> matrix_kronecker_product(const X& a, const Y& b)2362{2363matrix<X::num_rows* Y::num_rows, X::num_cols* Y::num_cols, typename X::scalar_type> result;23642365for (uint32_t r = 0; r < X::num_rows; r++)2366{2367for (uint32_t c = 0; c < X::num_cols; c++)2368{2369for (uint32_t i = 0; i < Y::num_rows; i++)2370for (uint32_t j = 0; j < Y::num_cols; j++)2371result(r * Y::num_rows + i, c * Y::num_cols + j) = a(r, c) * b(i, j);2372}2373}23742375return result;2376}23772378template<typename X, typename Y> matrix<X::num_rows + Y::num_rows, X::num_cols, typename X::scalar_type> matrix_combine_vertically(const X& a, const Y& b)2379{2380matrix<X::num_rows + Y::num_rows, X::num_cols, typename X::scalar_type> result;23812382for (uint32_t r = 0; r < X::num_rows; r++)2383for (uint32_t c = 0; c < X::num_cols; c++)2384result(r, c) = a(r, c);23852386for (uint32_t r = 0; r < Y::num_rows; r++)2387for (uint32_t c = 0; c < Y::num_cols; c++)2388result(r + X::num_rows, c) = b(r, c);23892390return result;2391}23922393inline matrix88F get_haar8()2394{2395matrix22F haar2(23961, 1,23971, -1);2398matrix22F i2(23991, 0,24000, 1);2401matrix44F i4(24021, 0, 0, 0,24030, 1, 0, 0,24040, 0, 1, 0,24050, 0, 0, 1);24062407matrix<1, 2, float> b0; b0(0, 0) = 1; b0(0, 1) = 1;2408matrix<1, 2, float> b1; b1(0, 0) = 1.0f; b1(0, 1) = -1.0f;24092410matrix<2, 4, float> haar4_0 = matrix_kronecker_product(haar2, b0);2411matrix<2, 4, float> haar4_1 = matrix_kronecker_product(i2, b1);24122413matrix<4, 4, float> haar4 = matrix_combine_vertically(haar4_0, haar4_1);24142415matrix<4, 8, float> haar8_0 = matrix_kronecker_product(haar4, b0);2416matrix<4, 8, float> haar8_1 = matrix_kronecker_product(i4, b1);24172418haar8_0[2] *= sqrtf(2);2419haar8_0[3] *= sqrtf(2);2420haar8_1 *= 2.0f;24212422matrix<8, 8, float> haar8 = matrix_combine_vertically(haar8_0, haar8_1);24232424return haar8;2425}24262427inline matrix44F get_haar4()2428{2429const float sqrt2 = 1.4142135623730951f;24302431return matrix44F(2432.5f * 1, .5f * 1, .5f * 1, .5f * 1,2433.5f * 1, .5f * 1, .5f * -1, .5f * -1,2434.5f * sqrt2, .5f * -sqrt2, 0, 0,24350, 0, .5f * sqrt2, .5f * -sqrt2);2436}24372438template<typename T>2439inline matrix<2, 2, T> get_inverse_2x2(const matrix<2, 2, T>& m)2440{2441double a = m[0][0];2442double b = m[0][1];2443double c = m[1][0];2444double d = m[1][1];24452446double det = a * d - b * c;2447if (det != 0.0f)2448det = 1.0f / det;24492450matrix<2, 2, T> result;2451result[0][0] = static_cast<T>(d * det);2452result[0][1] = static_cast<T>(-b * det);2453result[1][0] = static_cast<T>(-c * det);2454result[1][1] = static_cast<T>(a * det);2455return result;2456}24572458} // namespace bu_math24592460namespace basisu2461{2462class tracked_stat2463{2464public:2465tracked_stat() { clear(); }24662467inline void clear() { m_num = 0; m_total = 0; m_total2 = 0; }24682469inline void update(int32_t val) { m_num++; m_total += val; m_total2 += val * val; }24702471inline tracked_stat& operator += (uint32_t val) { update(val); return *this; }24722473inline uint32_t get_number_of_values() { return m_num; }2474inline uint64_t get_total() const { return m_total; }2475inline uint64_t get_total2() const { return m_total2; }24762477inline float get_average() const { return m_num ? (float)m_total / m_num : 0.0f; };2478inline float get_std_dev() const { return m_num ? sqrtf((float)(m_num * m_total2 - m_total * m_total)) / m_num : 0.0f; }2479inline float get_variance() const { float s = get_std_dev(); return s * s; }24802481private:2482uint32_t m_num;2483int64_t m_total;2484int64_t m_total2;2485};24862487class tracked_stat_dbl2488{2489public:2490tracked_stat_dbl() { clear(); }24912492inline void clear() { m_num = 0; m_total = 0; m_total2 = 0; }24932494inline void update(double val) { m_num++; m_total += val; m_total2 += val * val; }24952496inline tracked_stat_dbl& operator += (double val) { update(val); return *this; }24972498inline uint64_t get_number_of_values() { return m_num; }2499inline double get_total() const { return m_total; }2500inline double get_total2() const { return m_total2; }25012502inline double get_average() const { return m_num ? m_total / (double)m_num : 0.0f; };2503inline double get_std_dev() const { return m_num ? sqrt((double)(m_num * m_total2 - m_total * m_total)) / m_num : 0.0f; }2504inline double get_variance() const { double s = get_std_dev(); return s * s; }25052506private:2507uint64_t m_num;2508double m_total;2509double m_total2;2510};25112512template<typename FloatType>2513struct stats2514{2515uint32_t m_n;2516FloatType m_total, m_total_sq; // total, total of squares values2517FloatType m_avg, m_avg_sq; // mean, mean of the squared values2518FloatType m_rms; // sqrt(m_avg_sq)2519FloatType m_std_dev, m_var; // population standard deviation and variance2520FloatType m_mad; // mean absolute deviation2521FloatType m_min, m_max, m_range; // min and max values, and max-min2522FloatType m_len; // length of values as a vector (Euclidean norm or L2 norm)2523FloatType m_coeff_of_var; // coefficient of variation (std_dev/mean), High CV: Indicates greater variability relative to the mean, meaning the data values are more spread out,2524// Low CV : Indicates less variability relative to the mean, meaning the data values are more consistent.25252526FloatType m_skewness; // Skewness = 0: The data is perfectly symmetric around the mean,2527// Skewness > 0: The data is positively skewed (right-skewed),2528// Skewness < 0: The data is negatively skewed (left-skewed)2529// 0-.5 approx. symmetry, .5-1 moderate skew, >= 1 highly skewed25302531FloatType m_kurtosis; // Excess Kurtosis: Kurtosis = 0: The distribution has normal kurtosis (mesokurtic)2532// Kurtosis > 0: The distribution is leptokurtic, with heavy tails and a sharp peak2533// Kurtosis < 0: The distribution is platykurtic, with light tails and a flatter peak25342535bool m_any_zero;25362537FloatType m_median;2538uint32_t m_median_index;25392540stats()2541{2542clear();2543}25442545void clear()2546{2547m_n = 0;2548m_total = 0, m_total_sq = 0;2549m_avg = 0, m_avg_sq = 0;2550m_rms = 0;2551m_std_dev = 0, m_var = 0;2552m_mad = 0;2553m_min = BIG_FLOAT_VAL, m_max = -BIG_FLOAT_VAL; m_range = 0.0f;2554m_len = 0;2555m_coeff_of_var = 0;2556m_skewness = 0;2557m_kurtosis = 0;2558m_any_zero = false;25592560m_median = 0;2561m_median_index = 0;2562}25632564template<typename T>2565void calc_median(uint32_t n, const T* pVals, uint32_t stride = 1)2566{2567m_median = 0;2568m_median_index = 0;25692570if (!n)2571return;25722573basisu::vector< std::pair<T, uint32_t> > vals(n);25742575for (uint32_t i = 0; i < n; i++)2576{2577vals[i].first = pVals[i * stride];2578vals[i].second = i;2579}25802581std::sort(vals.begin(), vals.end(), [](const std::pair<T, uint32_t>& a, const std::pair<T, uint32_t>& b) {2582return a.first < b.first;2583});25842585m_median = vals[n / 2].first;2586if ((n & 1) == 0)2587m_median = (m_median + vals[(n / 2) - 1].first) * .5f;25882589m_median_index = vals[n / 2].second;2590}25912592template<typename T>2593void calc(uint32_t n, const T* pVals, uint32_t stride = 1, bool calc_median_flag = false)2594{2595clear();25962597if (!n)2598return;25992600if (calc_median_flag)2601calc_median(n, pVals, stride);26022603m_n = n;26042605for (uint32_t i = 0; i < n; i++)2606{2607FloatType v = (FloatType)pVals[i * stride];26082609if (v == 0.0f)2610m_any_zero = true;26112612m_total += v;2613m_total_sq += v * v;26142615if (!i)2616{2617m_min = v;2618m_max = v;2619}2620else2621{2622m_min = minimum(m_min, v);2623m_max = maximum(m_max, v);2624}2625}26262627m_range = m_max - m_min;26282629m_len = sqrt(m_total_sq);26302631const FloatType nd = (FloatType)n;26322633m_avg = m_total / nd;2634m_avg_sq = m_total_sq / nd;2635m_rms = sqrt(m_avg_sq);26362637for (uint32_t i = 0; i < n; i++)2638{2639FloatType v = (FloatType)pVals[i * stride];2640FloatType d = v - m_avg;26412642const FloatType d2 = d * d;2643const FloatType d3 = d2 * d;2644const FloatType d4 = d3 * d;26452646m_var += d2;2647m_mad += fabs(d);2648m_skewness += d3;2649m_kurtosis += d4;2650}26512652m_var /= nd;2653m_mad /= nd;26542655m_std_dev = sqrt(m_var);26562657m_coeff_of_var = (m_avg != 0.0f) ? (m_std_dev / fabs(m_avg)) : 0.0f;26582659FloatType k3 = m_std_dev * m_std_dev * m_std_dev;2660FloatType k4 = k3 * m_std_dev;2661m_skewness = (k3 != 0.0f) ? ((m_skewness / nd) / k3) : 0.0f;2662m_kurtosis = (k4 != 0.0f) ? (((m_kurtosis / nd) / k4) - 3.0f) : 0.0f;2663}26642665// Only compute average, variance and standard deviation.2666template<typename T>2667void calc_simplified(uint32_t n, const T* pVals, uint32_t stride = 1)2668{2669clear();26702671if (!n)2672return;26732674m_n = n;26752676for (uint32_t i = 0; i < n; i++)2677{2678FloatType v = (FloatType)pVals[i * stride];26792680m_total += v;2681}26822683const FloatType nd = (FloatType)n;26842685m_avg = m_total / nd;26862687for (uint32_t i = 0; i < n; i++)2688{2689FloatType v = (FloatType)pVals[i * stride];2690FloatType d = v - m_avg;26912692const FloatType d2 = d * d;26932694m_var += d2;2695}26962697m_var /= nd;2698m_std_dev = sqrt(m_var);2699}2700};27012702template<typename FloatType>2703struct comparative_stats2704{2705FloatType m_cov; // covariance2706FloatType m_pearson; // Pearson Correlation Coefficient (r) [-1,1]2707FloatType m_mse; // mean squared error2708FloatType m_rmse; // root mean squared error2709FloatType m_mae; // mean abs error2710FloatType m_rmsle; // root mean squared log error2711FloatType m_euclidean_dist; // euclidean distance between values as vectors2712FloatType m_cosine_sim; // normalized dot products of values as vectors2713FloatType m_min_diff, m_max_diff; // minimum/maximum abs difference between values27142715comparative_stats()2716{2717clear();2718}27192720void clear()2721{2722m_cov = 0;2723m_pearson = 0;2724m_mse = 0;2725m_rmse = 0;2726m_mae = 0;2727m_rmsle = 0;2728m_euclidean_dist = 0;2729m_cosine_sim = 0;2730m_min_diff = 0;2731m_max_diff = 0;2732}27332734template<typename T>2735void calc(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType> *pA_stats = nullptr, const stats<FloatType> *pB_stats = nullptr)2736{2737clear();2738if (!n)2739return;27402741stats<FloatType> temp_a_stats;2742if (!pA_stats)2743{2744pA_stats = &temp_a_stats;2745temp_a_stats.calc(n, pA, a_stride);2746}27472748stats<FloatType> temp_b_stats;2749if (!pB_stats)2750{2751pB_stats = &temp_b_stats;2752temp_b_stats.calc(n, pB, b_stride);2753}27542755for (uint32_t i = 0; i < n; i++)2756{2757const FloatType fa = (FloatType)pA[i * a_stride];2758const FloatType fb = (FloatType)pB[i * b_stride];27592760if ((pA_stats->m_min >= 0.0f) && (pB_stats->m_min >= 0.0f))2761{2762const FloatType ld = log(fa + 1.0f) - log(fb + 1.0f);2763m_rmsle += ld * ld;2764}27652766const FloatType diff = fa - fb;2767const FloatType abs_diff = fabs(diff);27682769m_mse += diff * diff;2770m_mae += abs_diff;27712772m_min_diff = i ? minimum(m_min_diff, abs_diff) : abs_diff;2773m_max_diff = maximum(m_max_diff, abs_diff);27742775const FloatType da = fa - pA_stats->m_avg;2776const FloatType db = fb - pB_stats->m_avg;2777m_cov += da * db;27782779m_cosine_sim += fa * fb;2780}27812782const FloatType nd = (FloatType)n;27832784m_euclidean_dist = sqrt(m_mse);27852786m_mse /= nd;2787m_rmse = sqrt(m_mse);27882789m_mae /= nd;27902791m_cov /= nd;27922793FloatType dv = (pA_stats->m_std_dev * pB_stats->m_std_dev);2794if (dv != 0.0f)2795m_pearson = m_cov / dv;27962797if ((pA_stats->m_min >= 0.0) && (pB_stats->m_min >= 0.0f))2798m_rmsle = sqrt(m_rmsle / nd);27992800FloatType c = pA_stats->m_len * pB_stats->m_len;2801if (c != 0.0f)2802m_cosine_sim /= c;2803else2804m_cosine_sim = 0.0f;2805}28062807// Only computes Pearson, cov, mse, rmse, Euclidean distance2808template<typename T>2809void calc_pearson(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType>* pA_stats = nullptr, const stats<FloatType>* pB_stats = nullptr)2810{2811clear();2812if (!n)2813return;28142815stats<FloatType> temp_a_stats;2816if (!pA_stats)2817{2818pA_stats = &temp_a_stats;2819temp_a_stats.calc(n, pA, a_stride);2820}28212822stats<FloatType> temp_b_stats;2823if (!pB_stats)2824{2825pB_stats = &temp_b_stats;2826temp_b_stats.calc(n, pB, b_stride);2827}28282829for (uint32_t i = 0; i < n; i++)2830{2831const FloatType fa = (FloatType)pA[i * a_stride];2832const FloatType fb = (FloatType)pB[i * b_stride];28332834const FloatType diff = fa - fb;28352836m_mse += diff * diff;28372838const FloatType da = fa - pA_stats->m_avg;2839const FloatType db = fb - pB_stats->m_avg;2840m_cov += da * db;2841}28422843const FloatType nd = (FloatType)n;28442845m_euclidean_dist = sqrt(m_mse);28462847m_mse /= nd;2848m_rmse = sqrt(m_mse);28492850m_cov /= nd;28512852FloatType dv = (pA_stats->m_std_dev * pB_stats->m_std_dev);2853if (dv != 0.0f)2854m_pearson = m_cov / dv;2855}28562857// Only computes MSE, RMSE, eclidiean distance, and covariance.2858template<typename T>2859void calc_simplified(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType>* pA_stats = nullptr, const stats<FloatType>* pB_stats = nullptr)2860{2861clear();2862if (!n)2863return;28642865stats<FloatType> temp_a_stats;2866if (!pA_stats)2867{2868pA_stats = &temp_a_stats;2869temp_a_stats.calc(n, pA, a_stride);2870}28712872stats<FloatType> temp_b_stats;2873if (!pB_stats)2874{2875pB_stats = &temp_b_stats;2876temp_b_stats.calc(n, pB, b_stride);2877}28782879for (uint32_t i = 0; i < n; i++)2880{2881const FloatType fa = (FloatType)pA[i * a_stride];2882const FloatType fb = (FloatType)pB[i * b_stride];28832884const FloatType diff = fa - fb;28852886m_mse += diff * diff;28872888const FloatType da = fa - pA_stats->m_avg;2889const FloatType db = fb - pB_stats->m_avg;2890m_cov += da * db;2891}28922893const FloatType nd = (FloatType)n;28942895m_euclidean_dist = sqrt(m_mse);28962897m_mse /= nd;2898m_rmse = sqrt(m_mse);28992900m_cov /= nd;2901}29022903// Only computes covariance.2904template<typename T>2905void calc_cov(uint32_t n, const T* pA, const T* pB, uint32_t a_stride = 1, uint32_t b_stride = 1, const stats<FloatType>* pA_stats = nullptr, const stats<FloatType>* pB_stats = nullptr)2906{2907clear();2908if (!n)2909return;29102911stats<FloatType> temp_a_stats;2912if (!pA_stats)2913{2914pA_stats = &temp_a_stats;2915temp_a_stats.calc(n, pA, a_stride);2916}29172918stats<FloatType> temp_b_stats;2919if (!pB_stats)2920{2921pB_stats = &temp_b_stats;2922temp_b_stats.calc(n, pB, b_stride);2923}29242925for (uint32_t i = 0; i < n; i++)2926{2927const FloatType fa = (FloatType)pA[i * a_stride];2928const FloatType fb = (FloatType)pB[i * b_stride];29292930const FloatType da = fa - pA_stats->m_avg;2931const FloatType db = fb - pB_stats->m_avg;2932m_cov += da * db;2933}29342935const FloatType nd = (FloatType)n;29362937m_cov /= nd;2938}2939};29402941class stat_history2942{2943public:2944stat_history(uint32_t size)2945{2946init(size);2947}29482949void init(uint32_t size)2950{2951clear();29522953m_samples.reserve(size);2954m_samples.resize(0);2955m_max_samples = size;2956}29572958inline void clear()2959{2960m_samples.resize(0);2961m_max_samples = 0;2962}29632964inline void update(double val)2965{2966m_samples.push_back(val);29672968if (m_samples.size() > m_max_samples)2969m_samples.erase_index(0);2970}29712972inline size_t size()2973{2974return m_samples.size();2975}29762977struct stats2978{2979double m_avg = 0;2980double m_std_dev = 0;2981double m_var = 0;2982double m_mad = 0;2983double m_min_val = 0;2984double m_max_val = 0;29852986void clear()2987{2988basisu::clear_obj(*this);2989}2990};29912992inline void get_stats(stats& s)2993{2994s.clear();29952996if (m_samples.empty())2997return;29982999double total = 0, total2 = 0;30003001for (size_t i = 0; i < m_samples.size(); i++)3002{3003const double v = m_samples[i];30043005total += v;3006total2 += v * v;30073008if (!i)3009{3010s.m_min_val = v;3011s.m_max_val = v;3012}3013else3014{3015s.m_min_val = basisu::minimum<double>(s.m_min_val, v);3016s.m_max_val = basisu::maximum<double>(s.m_max_val, v);3017}3018}30193020const double n = (double)m_samples.size();30213022s.m_avg = total / n;3023s.m_std_dev = sqrt((n * total2 - total * total)) / n;3024s.m_var = (n * total2 - total * total) / (n * n);30253026double sc = 0;3027for (size_t i = 0; i < m_samples.size(); i++)3028{3029const double v = m_samples[i];3030s.m_mad += fabs(v - s.m_avg);30313032sc += basisu::square(v - s.m_avg);3033}3034sc = sqrt(sc / n);30353036s.m_mad /= n;3037}30383039private:3040uint32_t m_max_samples;3041basisu::vector<double> m_samples;3042};30433044// bfloat16 helpers, see:3045// https://en.wikipedia.org/wiki/Bfloat16_floating-point_format30463047typedef union3048{3049uint32_t u;3050float f;3051} float32_union;30523053typedef uint16_t bfloat16;30543055inline float bfloat16_to_float(bfloat16 bfloat16)3056{3057float32_union float_union;3058float_union.u = ((uint32_t)bfloat16) << 16;3059return float_union.f;3060}30613062inline bfloat16 float_to_bfloat16(float input, bool round_flag = true)3063{3064float32_union float_union;3065float_union.f = input;30663067uint32_t exponent = (float_union.u >> 23) & 0xFF;30683069// Check if the number is denormalized in float32 (exponent == 0)3070if (exponent == 0)3071{3072// Handle denormalized float32 as zero in bfloat163073return 0x0000;3074}30753076// Extract the top 16 bits (sign, exponent, and 7 most significant bits of the mantissa)3077uint32_t upperBits = float_union.u >> 16;30783079if (round_flag)3080{3081// Check the most significant bit of the lower 16 bits for rounding3082uint32_t lowerBits = float_union.u & 0xFFFF;30833084// Round to nearest or even3085if ((lowerBits & 0x8000) &&3086((lowerBits > 0x8000) || ((lowerBits == 0x8000) && (upperBits & 1)))3087)3088{3089// Round up3090upperBits += 1;30913092// Check for overflow in the exponent after rounding up3093if (((upperBits & 0x7F80) == 0x7F80) && ((upperBits & 0x007F) == 0))3094{3095// Exponent overflow (the upper bits became all 1s)3096// Set the result to infinity3097upperBits = (upperBits & 0x8000) | 0x7F80; // Preserve the sign bit, set exponent to 0xFF, and mantissa to 03098}3099}3100}31013102return (bfloat16)upperBits;3103}31043105inline int bfloat16_get_exp(bfloat16 v)3106{3107return (int)((v >> 7) & 0xFF) - 127;3108}31093110inline int bfloat16_get_mantissa(bfloat16 v)3111{3112return (v & 0x7F);3113}31143115inline int bfloat16_get_sign(bfloat16 v)3116{3117return (v & 0x8000) ? -1 : 1;3118}31193120inline bool bfloat16_is_nan_or_inf(bfloat16 v)3121{3122return ((v >> 7) & 0xFF) == 0xFF;3123}31243125inline bool bfloat16_is_zero(bfloat16 v)3126{3127return (v & 0x7FFF) == 0;3128}31293130inline bfloat16 bfloat16_init(int sign, int exp, int mant)3131{3132uint16_t res = (sign < 0) ? 0x8000 : 0;31333134assert((exp >= -126) && (res <= 127));3135res |= ((exp + 127) << 7);31363137assert((mant >= 0) && (mant < 128));3138res |= mant;31393140return res;3141}314231433144} // namespace basisu3145314631473148